// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include "exprs/scalar-expr.inline.h"

#include <sstream>
#include <thrift/protocol/TDebugProtocol.h>

#include "codegen/codegen-anyval.h"
#include "codegen/llvm-codegen.h"
#include "common/object-pool.h"
#include "common/status.h"
#include "exprs/case-expr.h"
#include "exprs/cast-format-expr.h"
#include "exprs/compound-predicates.h"
#include "exprs/conditional-functions.h"
#include "exprs/hive-udf-call.h"
#include "exprs/in-predicate.h"
#include "exprs/is-not-empty-predicate.h"
#include "exprs/is-null-predicate.h"
#include "exprs/kudu-partition-expr.h"
#include "exprs/like-predicate.h"
#include "exprs/literal.h"
#include "exprs/null-literal.h"
#include "exprs/operators.h"
#include "exprs/scalar-expr-evaluator.h"
#include "exprs/scalar-fn-call.h"
#include "exprs/slot-ref.h"
#include "exprs/string-functions.h"
#include "exprs/timestamp-functions.h"
#include "exprs/tuple-is-null-predicate.h"
#include "exprs/udf-builtins.h"
#include "exprs/utility-functions.h"
#include "exprs/valid-tuple-id.h"
#include "runtime/runtime-state.h"
#include "runtime/tuple-row.h"
#include "runtime/tuple.h"
#include "runtime/types.h"
#include "udf/udf-internal.h"
#include "udf/udf.h"

#include "gen-cpp/Exprs_types.h"
#include "gen-cpp/ImpalaService_types.h"

#include "common/names.h"

using namespace impala_udf;

namespace impala {

// Name string for the ScalarExpr class in "class.<namespace>::<class>" form.
// NOTE(review): presumably the struct type name under which ScalarExpr appears in the
// LLVM IR module (cf. the GetStructPtrType<ScalarExpr>() call in this file) - confirm.
const char* ScalarExpr::LLVM_CLASS_NAME = "class.impala::ScalarExpr";

// Constructs a scalar expr with result type 'type'. 'is_constant' is stored verbatim in
// is_constant_; no function descriptor is set by this constructor.
ScalarExpr::ScalarExpr(const ColumnType& type, bool is_constant)
  : Expr(type), is_constant_(is_constant) {}

// Constructs a scalar expr from its thrift description. The constness flag is taken
// from 'node.is_constant'; the function descriptor is copied only if the node has one.
ScalarExpr::ScalarExpr(const TExprNode& node)
  : Expr(node), is_constant_(node.is_constant) {
  if (!node.__isset.fn) return;
  fn_ = node.fn;
}

// Builds the ScalarExpr tree described by 'texpr' and returns its root in
// '*scalar_expr'. The root node is created from 'texpr.nodes[0]', the remainder of the
// tree is attached by Expr::CreateTree(), the tree is initialized against 'row_desc' and
// finally function context indices are assigned starting from 0.
//
// '*scalar_expr' is set to nullptr up front and is only overwritten on success. Returns
// an error if 'texpr' contains no nodes or if node creation, tree creation or Init()
// fails; when Init() fails the partially initialized tree is closed before returning.
Status ScalarExpr::Create(const TExpr& texpr, const RowDescriptor& row_desc,
    RuntimeState* state, ObjectPool* pool, ScalarExpr** scalar_expr) {
  *scalar_expr = nullptr;
  // Indexing nodes[0] of an empty vector is undefined behaviour, so reject a malformed
  // thrift expr explicitly instead.
  if (UNLIKELY(texpr.nodes.empty())) {
    return Status("Expression has no nodes in thrift expr");
  }
  ScalarExpr* root = nullptr;
  RETURN_IF_ERROR(CreateNode(texpr.nodes[0], pool, &root));
  RETURN_IF_ERROR(Expr::CreateTree(texpr, pool, root));
  // Assume that the root is a potential entry point for interpreted callers.
  // This is not always true but would require some work to determine for
  // each of the callsites of Create().
  // TODO: fix this - reducing the number of entry points would reduce codegen overhead
  // somewhat.
  Status status = root->Init(row_desc, /*is_entry_point*/ true, state);
  if (UNLIKELY(!status.ok())) {
    root->Close();
    return status;
  }
  int fn_ctx_idx = 0;
  root->AssignFnCtxIdx(&fn_ctx_idx);
  *scalar_expr = root;
  return Status::OK();
}

// Creates one ScalarExpr per entry of 'texprs', in order, and appends each root to
// '*exprs' (which is cleared first). Stops at the first failure and returns its status;
// exprs created before the failure remain in '*exprs'.
Status ScalarExpr::Create(const vector<TExpr>& texprs, const RowDescriptor& row_desc,
    RuntimeState* state, ObjectPool* pool, vector<ScalarExpr*>* exprs) {
  exprs->clear();
  for (size_t idx = 0; idx < texprs.size(); ++idx) {
    ScalarExpr* created;
    RETURN_IF_ERROR(Create(texprs[idx], row_desc, state, pool, &created));
    DCHECK(created != nullptr);
    exprs->push_back(created);
  }
  return Status::OK();
}

// Convenience overload: same as the five-argument Create() with the object pool taken
// from 'state->obj_pool()'.
Status ScalarExpr::Create(const TExpr& texpr, const RowDescriptor& row_desc,
    RuntimeState* state, ScalarExpr** scalar_expr) {
  ObjectPool* state_pool = state->obj_pool();
  return ScalarExpr::Create(texpr, row_desc, state, state_pool, scalar_expr);
}

// Convenience overload: same as the five-argument vector Create() with the object pool
// taken from 'state->obj_pool()'.
Status ScalarExpr::Create(const vector<TExpr>& texprs, const RowDescriptor& row_desc,
    RuntimeState* state, vector<ScalarExpr*>* exprs) {
  ObjectPool* state_pool = state->obj_pool();
  return ScalarExpr::Create(texprs, row_desc, state, state_pool, exprs);
}

// Assigns function context indices to this expr and, recursively, to its children in
// pre-order. '*next_fn_ctx_idx' is the next unused index; it is advanced by one for
// every expr for which HasFnCtx() is true. On return, [fn_ctx_idx_start_,
// fn_ctx_idx_end_) is the index range consumed by the subtree rooted at this expr.
void ScalarExpr::AssignFnCtxIdx(int* next_fn_ctx_idx) {
  int& next_idx = *next_fn_ctx_idx;
  fn_ctx_idx_start_ = next_idx;
  if (HasFnCtx()) fn_ctx_idx_ = next_idx++;
  for (ScalarExpr* child : children()) {
    child->AssignFnCtxIdx(next_fn_ctx_idx);
  }
  fn_ctx_idx_end_ = next_idx;
}

// Instantiates the ScalarExpr subclass that corresponds to 'texpr_node.node_type' and
// stores it in '*expr'. The new object is registered with 'pool' via ObjectPool::Add().
// Only this single node is created here.
//
// Returns an error status if a thrift field required by the node type is not set
// (case_expr, slot_ref, or fn for function calls and compound predicates) or if the node
// type is not handled; in the latter case '*expr' is set to nullptr.
Status ScalarExpr::CreateNode(
    const TExprNode& texpr_node, ObjectPool* pool, ScalarExpr** expr) {
  switch (texpr_node.node_type) {
    case TExprNodeType::BOOL_LITERAL:
    case TExprNodeType::FLOAT_LITERAL:
    case TExprNodeType::INT_LITERAL:
    case TExprNodeType::STRING_LITERAL:
    case TExprNodeType::DECIMAL_LITERAL:
    case TExprNodeType::TIMESTAMP_LITERAL:
    case TExprNodeType::DATE_LITERAL:
      *expr = pool->Add(new Literal(texpr_node));
      return Status::OK();
    case TExprNodeType::CASE_EXPR:
      if (!texpr_node.__isset.case_expr) {
        return Status("Case expression not set in thrift node");
      }
      *expr = pool->Add(new CaseExpr(texpr_node));
      return Status::OK();
    case TExprNodeType::COMPOUND_PRED: {
      // The dispatch below is driven by the function name, so the function descriptor
      // must be present, just as for FUNCTION_CALL. Without this check a node with an
      // unset 'fn' would silently fall through to the "not" branch in release builds.
      if (!texpr_node.__isset.fn) {
        return Status("Compound predicate function not set in thrift node");
      }
      const string& fn_name = texpr_node.fn.name.function_name;
      if (fn_name == "and") {
        *expr = pool->Add(new AndPredicate(texpr_node));
      } else if (fn_name == "or") {
        *expr = pool->Add(new OrPredicate(texpr_node));
      } else {
        DCHECK_EQ(fn_name, "not");
        *expr = pool->Add(new ScalarFnCall(texpr_node));
      }
      return Status::OK();
    }
    case TExprNodeType::NULL_LITERAL:
      *expr = pool->Add(new NullLiteral(texpr_node));
      return Status::OK();
    case TExprNodeType::SLOT_REF:
      if (!texpr_node.__isset.slot_ref) {
        return Status("Slot reference not set in thrift node");
      }
      *expr = pool->Add(new SlotRef(texpr_node));
      return Status::OK();
    case TExprNodeType::TUPLE_IS_NULL_PRED:
      *expr = pool->Add(new TupleIsNullPredicate(texpr_node));
      return Status::OK();
    case TExprNodeType::FUNCTION_CALL: {
      if (!texpr_node.__isset.fn) {
        return Status("Function not set in thrift node");
      }
      // Special-case functions that have their own Expr classes
      // TODO: is there a better way to do this?
      const string& fn_name = texpr_node.fn.name.function_name;
      if (fn_name == "if") {
        *expr = pool->Add(new IfExpr(texpr_node));
      } else if (fn_name == "isnull" || fn_name == "ifnull" || fn_name == "nvl") {
        *expr = pool->Add(new IsNullExpr(texpr_node));
      } else if (fn_name == "coalesce") {
        *expr = pool->Add(new CoalesceExpr(texpr_node));
      } else if (texpr_node.fn.binary_type == TFunctionBinaryType::JAVA) {
        *expr = pool->Add(new HiveUdfCall(texpr_node));
      } else if (texpr_node.__isset.cast_expr &&
          !texpr_node.cast_expr.cast_format.empty()) {
        // A cast that carries a non-empty format string gets its own expr class.
        *expr = pool->Add(new CastFormatExpr(texpr_node));
      } else {
        *expr = pool->Add(new ScalarFnCall(texpr_node));
      }
      return Status::OK();
    }
    case TExprNodeType::IS_NOT_EMPTY_PRED:
      *expr = pool->Add(new IsNotEmptyPredicate(texpr_node));
      return Status::OK();
    case TExprNodeType::KUDU_PARTITION_EXPR:
      *expr = pool->Add(new KuduPartitionExpr(texpr_node));
      return Status::OK();
    case TExprNodeType::VALID_TUPLE_ID_EXPR:
      *expr = pool->Add(new ValidTupleIdExpr(texpr_node));
      return Status::OK();
    default: {
      *expr = nullptr;
      stringstream os;
      os << "Unknown expr node type: " << texpr_node.node_type;
      return Status(os.str());
    }
  }
}

// Base implementation: forwards OpenEvaluator() to every child in order and returns the
// first error encountered, or OK if all children succeed.
Status ScalarExpr::OpenEvaluator(FunctionContext::FunctionStateScope scope,
    RuntimeState* state, ScalarExprEvaluator* eval) const {
  for (ScalarExpr* child : children_) {
    RETURN_IF_ERROR(child->OpenEvaluator(scope, state, eval));
  }
  return Status::OK();
}

// Base implementation: forwards CloseEvaluator() to every child in order.
void ScalarExpr::CloseEvaluator(FunctionContext::FunctionStateScope scope,
    RuntimeState* state, ScalarExprEvaluator* eval) const {
  for (size_t idx = 0; idx < children_.size(); ++idx) {
    children_[idx]->CloseEvaluator(scope, state, eval);
  }
}

// Closes this expr by delegating to Expr::Close(); no ScalarExpr-specific teardown is
// performed here.
void ScalarExpr::Close() {
  Expr::Close();
}

void ScalarExpr::Close(const vector<ScalarExpr*>& exprs) {
  for (ScalarExpr* expr : exprs) expr->Close();
}

// Per-expr record used when laying out expr results in a buffer.
struct MemLayoutData {
  int expr_idx;          // Index of the expr in the caller's expr vector.
  int byte_size;         // Size in bytes of the expr's result slot.
  bool variable_length;  // True if the result is a variable-length value.
  int alignment;

  // Orders fixed-length entries before variable-length ones; entries of the same kind
  // are ordered by ascending byte size.
  // TODO: why put var-len at end?
  bool operator<(const MemLayoutData& rhs) const {
    // When exactly one side is variable-length, the fixed-length side sorts first, i.e.
    // 'this' is smaller precisely when 'rhs' is the variable-length one.
    if (variable_length != rhs.variable_length) return rhs.variable_length;
    return byte_size < rhs.byte_size;
  }
};

int ScalarExpr::ComputeResultsLayout(const vector<ScalarExpr*>& exprs,
    vector<int>* offsets, int* var_result_begin) {
  if (exprs.size() == 0) {
    *var_result_begin = -1;
    return 0;
  }


  vector<MemLayoutData> data;
  data.resize(exprs.size());

  // Collect all the byte sizes and sort them
  for (int i = 0; i < exprs.size(); ++i) {
    DCHECK(!exprs[i]->type().IsComplexType()) << "NYI";
    data[i].expr_idx = i;
    data[i].byte_size = exprs[i]->type().GetSlotSize();
    DCHECK_GT(data[i].byte_size, 0);
    data[i].variable_length = exprs[i]->type().IsVarLenStringType();

  }

  sort(data.begin(), data.end());

  int byte_offset = 0;
  offsets->resize(exprs.size());
  *var_result_begin = -1;

  for (int i = 0; i < data.size(); ++i) {

    (*offsets)[data[i].expr_idx] = byte_offset;
    if (data[i].variable_length && *var_result_begin == -1) {
      *var_result_begin = byte_offset;
    }
    DCHECK(!(i == 0 && byte_offset > 0)) << "first value should be at start of layout";
    byte_offset += data[i].byte_size;
  }

  return byte_offset;
}

// Initializes this expr and its subtree against 'row_desc'. Children are initialized
// first and are never treated as entry points by this call. If ShouldCodegen(state) is
// true, the expr is registered with 'state' for the codegen phase.
Status ScalarExpr::Init(
    const RowDescriptor& row_desc, bool is_entry_point, RuntimeState* state) {
  DCHECK(type_.type != INVALID_TYPE);
  for (ScalarExpr* child : children_) {
    RETURN_IF_ERROR(child->Init(row_desc, false, state));
  }
  if (!ShouldCodegen(state)) return Status::OK();

  // Add the expression to the list of expressions to codegen in the codegen phase.
  // If the expression is not interpretable, we need an entry point to evaluate
  // the expression from interpreted code, e.g. GetConstValue().
  const bool needs_entry_point = is_entry_point || !IsInterpretable();
  state->AddScalarExprToCodegen(this, needs_entry_point);
  return Status::OK();
}

// Returns " type=<type>" followed by " children=[...]" when the expr has children.
string ScalarExpr::DebugString() const {
  // TODO: implement partial debug string for member vars
  stringstream out;
  out << " type=" << type_.DebugString();
  if (children_.empty()) return out.str();
  out << " children=" << DebugString(children_);
  return out.str();
}

// Returns the debug strings of 'exprs' separated by single spaces and wrapped in
// square brackets; an empty vector yields "[]".
string ScalarExpr::DebugString(const vector<ScalarExpr*>& exprs) {
  stringstream out;
  out << "[";
  const char* separator = "";
  for (ScalarExpr* expr : exprs) {
    out << separator << expr->DebugString();
    separator = " ";
  }
  out << "]";
  return out.str();
}

// Returns true if this expr should be codegen'd. The interpreted path (calling the
// builtin without codegen) is used instead if any of the checks below applies.
bool ScalarExpr::ShouldCodegen(const RuntimeState* state) const {
  // 1. The expression does not have an associated RuntimeState, e.g. is a partition
  //    key expression in a descriptor table.
  if (state == nullptr) return false;
  // 2. Codegen is disabled by query option.
  if (state->CodegenDisabledByQueryOption()) return false;
  // 3. There is an optimization hint to disable codegen and the expr can be interpreted.
  if (state->CodegenHasDisableHint() && IsInterpretable()) return false;
  return true;
}

// Base implementation: forwards to every child with the same 'slot_ids' vector and
// returns the sum of the children's return values. Nothing is added for this expr.
int ScalarExpr::GetSlotIds(vector<SlotId>* slot_ids) const {
  int total = 0;
  for (ScalarExpr* child : children_) {
    total += child->GetSlotIds(slot_ids);
  }
  return total;
}

// Looks up, via 'codegen', the IR function registered under the
// IRFunction::SCALAR_EXPR_GET_*_VAL id that matches 'type'. STRING, CHAR and VARCHAR all
// share the string variant. Any other type trips a DCHECK and yields nullptr.
llvm::Function* ScalarExpr::GetStaticGetValWrapper(
    ColumnType type, LlvmCodeGen* codegen) {
  switch (type.type) {
    case TYPE_BOOLEAN:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_BOOLEAN_VAL, false);
    case TYPE_TINYINT:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_TINYINT_VAL, false);
    case TYPE_SMALLINT:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_SMALLINT_VAL, false);
    case TYPE_INT:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_INT_VAL, false);
    case TYPE_BIGINT:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_BIGINT_VAL, false);
    case TYPE_FLOAT:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_FLOAT_VAL, false);
    case TYPE_DOUBLE:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_DOUBLE_VAL, false);
    // All string-like types are served by the same wrapper.
    case TYPE_STRING:
    case TYPE_CHAR:
    case TYPE_VARCHAR:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_STRING_VAL, false);
    case TYPE_TIMESTAMP:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_TIMESTAMP_VAL, false);
    case TYPE_DECIMAL:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_DECIMAL_VAL, false);
    case TYPE_DATE:
      return codegen->GetFunction(IRFunction::SCALAR_EXPR_GET_DATE_VAL, false);
    default:
      DCHECK(false) << "Invalid type: " << type.DebugString();
      return nullptr;
  }
}

// Generates the prototype of an IR function called 'name' that takes
// (ScalarExprEvaluator* eval, TupleRow* row) and returns the lowered AnyVal type for
// this expr's type(). The two argument values of the generated prototype are written to
// '(*args)[0]' and '(*args)[1]'.
llvm::Function* ScalarExpr::CreateIrFunctionPrototype(
    const string& name, LlvmCodeGen* codegen, llvm::Value* (*args)[2]) {
  llvm::Type* lowered_ret_type = CodegenAnyVal::GetLoweredType(codegen, type());
  LlvmCodeGen::FnPrototype prototype(codegen, name, lowered_ret_type);

  auto* eval_ptr_type = codegen->GetStructPtrType<ScalarExprEvaluator>();
  prototype.AddArgument(LlvmCodeGen::NamedVariable("eval", eval_ptr_type));

  auto* row_ptr_type = codegen->GetStructPtrType<TupleRow>();
  prototype.AddArgument(LlvmCodeGen::NamedVariable("row", row_ptr_type));

  // No builder is supplied; '*args' decays to a pointer to the first array element.
  llvm::Function* function = prototype.GeneratePrototype(nullptr, *args);
  DCHECK(function != nullptr);
  return function;
}

// Returns in '*fn' the IR compute function for this expr. The function is generated by
// GetCodegendComputeFnImpl() on the first call and cached in ir_compute_fn_ for later
// calls. The first time this is called with 'is_codegen_entry_point' set, the function
// is also added to the JIT so that codegend_compute_fn_ gets filled in.
Status ScalarExpr::GetCodegendComputeFn(
    LlvmCodeGen* codegen, bool is_codegen_entry_point, llvm::Function** fn) {
  if (ir_compute_fn_ == nullptr) {
    RETURN_IF_ERROR(GetCodegendComputeFnImpl(codegen, fn));
    ir_compute_fn_ = *fn;
  } else {
    *fn = ir_compute_fn_;
  }

  // Ensure Get*Val() is made callable if this function is called at least once
  // with is_codegen_entry_point=true.
  const bool needs_jit_registration = is_codegen_entry_point && !added_to_jit_;
  if (needs_jit_registration) {
    added_to_jit_ = true;
    codegen->AddFunctionToJit(*fn, &codegend_compute_fn_);
  }
  return Status::OK();
}

// Generates an IR function with the (eval, row) signature whose body just calls the
// static Get*Val wrapper for this expr's type, passing 'this' as the additional first
// argument. The finalized function is returned in '*fn'. Returns IR_VERIFY_FAILED if
// finalization yields nullptr, or the first error hit while codegen'ing a child.
Status ScalarExpr::GetCodegendComputeFnWrapper(
    LlvmCodeGen* codegen, llvm::Function** fn) {
  for (ScalarExpr* child : children_) {
    // The codegen'd function will call child->Get*Val(). Make sure the child is a
    // codegen entry point so that child->GetVal() uses the fast codegen'd path.
    llvm::Function* unused_child_fn;
    RETURN_IF_ERROR(child->GetCodegendComputeFn(codegen, true, &unused_child_fn));
  }

  llvm::Function* getval_wrapper = GetStaticGetValWrapper(type(), codegen);

  // Emit the wrapper: a single basic block that calls the static wrapper and returns
  // its result.
  llvm::Value* proto_args[2];
  llvm::Function* wrapper_fn =
      CreateIrFunctionPrototype("CodegenComputeFnWrapper", codegen, &proto_args);
  llvm::BasicBlock* entry_block =
      llvm::BasicBlock::Create(codegen->context(), "entry", wrapper_fn);
  LlvmBuilder builder(entry_block);
  llvm::Value* this_ptr =
      codegen->CastPtrToLlvmPtr(codegen->GetStructPtrType<ScalarExpr>(), this);
  llvm::Value* call_args[] = {this_ptr, proto_args[0], proto_args[1]};
  llvm::Value* ret = CodegenAnyVal::CreateCall(
      codegen, &builder, getval_wrapper, call_args, "ret");
  builder.CreateRet(ret);

  *fn = codegen->FinalizeFunction(wrapper_fn);
  if (UNLIKELY(*fn == nullptr)) {
    return Status(TErrorCode::IR_VERIFY_FAILED, "CodegendComputeFnWrapper");
  }
  return Status::OK();
}

// Defines the base-class ScalarExpr::Get<val_type>Interpreted() for one AnyVal type.
// The default body trips a DCHECK that reports the expr's debug string and returns a
// null value of that type.
#define SCALAR_EXPR_GET_VAL_INTERPRETED(val_type)                 \
  val_type ScalarExpr::Get##val_type##Interpreted(                \
      ScalarExprEvaluator* eval, const TupleRow* row) const {     \
    DCHECK(false) << DebugString();                               \
    return val_type::null();                                      \
  }

// At least one of these should always be overridden.
SCALAR_EXPR_GET_VAL_INTERPRETED(BooleanVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(TinyIntVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(SmallIntVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(IntVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(BigIntVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(FloatVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(DoubleVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(StringVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(TimestampVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(DecimalVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(DateVal);
SCALAR_EXPR_GET_VAL_INTERPRETED(CollectionVal);

// Returns "<expr_name>(<base debug string>)", where the inner part is always produced
// by ScalarExpr::DebugString() rather than by a subclass override.
string ScalarExpr::DebugString(const string& expr_name) const {
  return expr_name + "(" + ScalarExpr::DebugString() + ")";
}

}
