// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include "exprs/agg-fn.h"

#include "codegen/llvm-codegen.h"
#include "exprs/anyval-util.h"
#include "exprs/scalar-expr.h"
#include "runtime/descriptors.h"
#include "runtime/lib-cache.h"

#include "common/names.h"

using namespace impala_udf;

namespace impala {

AggFn::AggFn(const TExprNode& tnode, const SlotDescriptor& intermediate_slot_desc,
    const SlotDescriptor& output_slot_desc)
  : Expr(tnode),
    is_merge_(tnode.agg_expr.is_merge_agg),
    intermediate_slot_desc_(intermediate_slot_desc),
    output_slot_desc_(output_slot_desc),
    arg_type_descs_(AnyValUtil::ColumnTypesToTypeDescs(
        ColumnType::FromThrift(tnode.agg_expr.arg_types))) {
  DCHECK(tnode.__isset.fn);
  DCHECK(tnode.fn.__isset.aggregate_fn);
  DCHECK_EQ(tnode.node_type, TExprNodeType::AGGREGATE_EXPR);
  DCHECK_EQ(ColumnType::FromThrift(tnode.type).type,
      ColumnType::FromThrift(fn_.ret_type).type);
  const string& fn_name = fn_.name.function_name;
  if (fn_name == "count") {
    agg_op_ = COUNT;
  } else if (fn_name == "min") {
    agg_op_ = MIN;
  } else if (fn_name == "max") {
    agg_op_ = MAX;
  } else if (fn_name == "sum" || fn_name == "sum_init_zero") {
    agg_op_ = SUM;
  } else if (fn_name == "avg") {
    agg_op_ = AVG;
  } else if (fn_name == "ndv" || fn_name == "ndv_no_finalize") {
    agg_op_ = NDV;
  } else {
    agg_op_ = OTHER;
  }
}

Status AggFn::Init(const RowDescriptor& row_desc, RuntimeState* state) {
  // Initialize all children (i.e. input exprs to this aggregate expr).
  for (ScalarExpr* input_expr : children()) {
    RETURN_IF_ERROR(input_expr->Init(row_desc, /*is_entry_point*/ false, state));
  }

  // Initialize the aggregate expressions' internals.
  const TAggregateFunction& aggregate_fn = fn_.aggregate_fn;
  DCHECK_EQ(intermediate_slot_desc_.type().type,
      ColumnType::FromThrift(aggregate_fn.intermediate_type).type);
  DCHECK_EQ(output_slot_desc_.type().type, ColumnType::FromThrift(fn_.ret_type).type);

  time_t mtime = fn_.last_modified_time;
  // Load the function pointers. Must have init() and update().
  if (aggregate_fn.init_fn_symbol.empty() ||
      aggregate_fn.update_fn_symbol.empty() ||
      (aggregate_fn.merge_fn_symbol.empty() && !aggregate_fn.is_analytic_only_fn)) {
    // This path is only for partially implemented builtins.
    DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::BUILTIN);
    stringstream ss;
    ss << "Function " << fn_.name.function_name << " is not implemented.";
    return Status(ss.str());
  }

  RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(
      fn_.hdfs_location, aggregate_fn.init_fn_symbol, mtime, &init_fn_, &cache_entry_));
  RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
      aggregate_fn.update_fn_symbol, mtime, &update_fn_, &cache_entry_));

  // Merge() is not defined for purely analytic function.
  if (!aggregate_fn.is_analytic_only_fn) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.merge_fn_symbol, mtime, &merge_fn_, &cache_entry_));
  }
  // Serialize(), GetValue(), Remove() and Finalize() are optional
  if (!aggregate_fn.serialize_fn_symbol.empty()) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.serialize_fn_symbol, mtime, &serialize_fn_, &cache_entry_));
  }
  if (!aggregate_fn.get_value_fn_symbol.empty()) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.get_value_fn_symbol, mtime, &get_value_fn_, &cache_entry_));
  }
  if (!aggregate_fn.remove_fn_symbol.empty()) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.remove_fn_symbol, mtime, &remove_fn_, &cache_entry_));
  }
  if (!aggregate_fn.finalize_fn_symbol.empty()) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        fn_.aggregate_fn.finalize_fn_symbol, mtime, &finalize_fn_, &cache_entry_));
  }
  return Status::OK();
}

Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc,
    const SlotDescriptor& intermediate_slot_desc, const SlotDescriptor& output_slot_desc,
    RuntimeState* state, AggFn** agg_fn) {
  *agg_fn = nullptr;
  ObjectPool* pool = state->obj_pool();
  const TExprNode& texpr_node = texpr.nodes[0];
  DCHECK_EQ(texpr_node.node_type, TExprNodeType::AGGREGATE_EXPR);
  if (!texpr_node.__isset.fn) {
    return Status("Function not set in thrift AGGREGATE_EXPR node");
  }
  AggFn* new_agg_fn =
      pool->Add(new AggFn(texpr_node, intermediate_slot_desc, output_slot_desc));
  RETURN_IF_ERROR(Expr::CreateTree(texpr, pool, new_agg_fn));
  Status status = new_agg_fn->Init(row_desc, state);
  if (UNLIKELY(!status.ok())) {
    new_agg_fn->Close();
    return status;
  }
  for (ScalarExpr* input_expr : new_agg_fn->children()) {
    int fn_ctx_idx = 0;
    input_expr->AssignFnCtxIdx(&fn_ctx_idx);
  }
  *agg_fn = new_agg_fn;
  return Status::OK();
}

FunctionContext::TypeDesc AggFn::GetIntermediateTypeDesc() const {
  return AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_.type());
}

FunctionContext::TypeDesc AggFn::GetOutputTypeDesc() const {
  return AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_.type());
}

Status AggFn::CodegenUpdateOrMergeFunction(
    LlvmCodeGen* codegen, llvm::Function** uda_fn) {
  const string& symbol =
      is_merge_ ? fn_.aggregate_fn.merge_fn_symbol : fn_.aggregate_fn.update_fn_symbol;
  vector<ColumnType> fn_arg_types;
  for (ScalarExpr* input_expr : children()) {
    fn_arg_types.push_back(input_expr->type());
  }
  // The intermediate value is passed as the last argument.
  fn_arg_types.push_back(intermediate_type());
  RETURN_IF_ERROR(codegen->LoadFunction(fn_, symbol, nullptr, fn_arg_types,
      fn_arg_types.size(), false, uda_fn, &cache_entry_));

  // Inline constants into the function body (if there is an IR body).
  if (!(*uda_fn)->isDeclaration()) {
    // TODO: IMPALA-4785: we should also replace references to GetIntermediateType()
    // with constants.
    codegen->InlineConstFnAttrs(GetOutputTypeDesc(), arg_type_descs_, *uda_fn);
    *uda_fn = codegen->FinalizeFunction(*uda_fn);
    if (*uda_fn == nullptr) {
      return Status(TErrorCode::UDF_VERIFY_FAILED, symbol, fn_.hdfs_location);
    }
  }
  return Status::OK();
}

void AggFn::Close() {
  // This also closes all the input expressions.
  Expr::Close();
}

void AggFn::Close(const vector<AggFn*>& exprs) {
  for (AggFn* expr : exprs) expr->Close();
}

string AggFn::DebugString() const {
  stringstream out;
  out << "AggFn(op=" << agg_op_;
  for (ScalarExpr* input_expr : children()) {
    out << " " << input_expr->DebugString() << ")";
  }
  out << ")";
  return out.str();
}

string AggFn::DebugString(const vector<AggFn*>& agg_fns) {
  stringstream out;
  out << "[";
  for (int i = 0; i < agg_fns.size(); ++i) {
    out << (i == 0 ? "" : " ") << agg_fns[i]->DebugString();
  }
  out << "]";
  return out.str();
}

}
