// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// The following is cross-compiled to native code and IR, and used in the test below
#include "exprs/decimal-operators.h"
#include "exprs/scalar-expr.h"
#include "udf/udf.h"

#ifdef IR_COMPILE
#include "exprs/decimal-operators-ir.cc"
#endif

using namespace impala;
using namespace impala_udf;

// Attributes recorded by TestGetFnAttrs(): each field receives the value returned by
// FunctionContextImpl::GetConstFnAttr() for the return type or for the corresponding
// argument. The test fixture's SetUp() memsets the struct to -1 bytes before each test
// so that fields that were never written show up as failures in CheckFnAttr().
//
// This struct is used both by the native build and by the IR cross-compiled from this
// file (the JIT-compiled TestGetFnAttrs() writes into the fixture's instance through the
// function state), so keep the field order and types unchanged.
struct FnAttr {
  int return_type_size;  // RETURN_TYPE_SIZE
  int arg0_type_size;    // ARG_TYPE_SIZE for argument 0
  int arg1_type_size;    // ARG_TYPE_SIZE for argument 1
  int arg2_type_size;    // ARG_TYPE_SIZE for argument 2
};

// Test UDF, compiled both natively and to IR. It stores the constant function
// attributes reported by FunctionContextImpl::GetConstFnAttr() into the FnAttr that the
// caller installed as 'ctx''s THREAD_LOCAL function state. It then returns 'arg0' cast to
// the function's return type via DecimalOperators::CastToDecimalVal().
//
// NOTE: the parameter list must stay in sync with TEST_GET_FN_ATTR_SYMBOL (this
// function's mangled name), which the codegen test uses to look it up in the IR module.
// 'arg1' and 'arg2' are never read; they appear to exist only so that ARG_TYPE_SIZE can
// be queried for argument indexes 1 and 2.
DecimalVal TestGetFnAttrs(
    FunctionContext* ctx, const DecimalVal& arg0, BooleanVal& arg1, StringVal& arg2) {
  // Assumes the caller stored a FnAttr* as THREAD_LOCAL state (the fixture's SetUp()
  // does); there is no null check.
  FnAttr* state = reinterpret_cast<FnAttr*>(
      ctx->GetFunctionState(FunctionContext::THREAD_LOCAL));
  state->return_type_size =
      ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SIZE);
  state->arg0_type_size =
      ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0);
  state->arg1_type_size =
      ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 1);
  state->arg2_type_size =
      ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 2);
  // This function and its callees call FunctionContextImpl::GetConstFnAttr(). The
  // codegen test checks how many of those call sites LlvmCodeGen::InlineConstFnAttrs()
  // replaces, so adding or removing calls here changes that expected count.
  return DecimalOperators::CastToDecimalVal(ctx, arg0);
}

// Don't compile the actual test to IR
#ifndef IR_COMPILE

#include "testutil/gtest-util.h"
#include "codegen/codegen-util.h"
#include "codegen/llvm-codegen.h"
#include "common/init.h"
#include "exprs/anyval-util.h"
#include "exprs/scalar-expr.h"
#include "exprs/scalar-expr-evaluator.h"
#include "runtime/exec-env.h"
#include "runtime/mem-tracker.h"
#include "runtime/runtime-state.h"
#include "runtime/test-env.h"
#include "service/fe-support.h"
#include "udf/udf-internal.h"
#include "udf/udf-test-harness.h"

#include "gen-cpp/Exprs_types.h"

#include "common/names.h"

namespace impala {

// Itanium-ABI mangled name of
//   TestGetFnAttrs(FunctionContext*, const DecimalVal&, BooleanVal&, StringVal&)
// used to look the function up in the IR module. It must be updated whenever that
// signature changes. Both the pointer and the characters are const, so the symbol
// cannot be reassigned at runtime and the constant has internal linkage.
const char* const TEST_GET_FN_ATTR_SYMBOL =
    "_Z14TestGetFnAttrsPN10impala_udf15FunctionContextERKNS_10DecimalValERNS_10BooleanValERNS_9StringValE";

// Types shared by the fixture's FunctionContext and the thrift expr tree built in
// TestInlineConstFnAttrs. The UDF takes DECIMAL(ARG0_PRECISION, ARG0_SCALE) as
// argument 0 and returns DECIMAL(RET_PRECISION, RET_SCALE). ARG1_LEN is the type size
// CheckFnAttr() expects for the BOOLEAN argument 1.
const int ARG0_PRECISION = 10;
const int ARG0_SCALE = 2;
const int ARG1_LEN = 1;
const int RET_PRECISION = 10;
const int RET_SCALE = 1;

class ExprCodegenTest : public ::testing::Test {
 protected:
  scoped_ptr<TestEnv> test_env_;
  RuntimeState* runtime_state_;
  FunctionContext* fn_ctx_;
  FnAttr fn_type_attr_;

  int InlineConstFnAttrs(const Expr* expr, LlvmCodeGen* codegen, llvm::Function* fn) {
    FunctionContext::TypeDesc ret_type = AnyValUtil::ColumnTypeToTypeDesc(expr->type());
    vector<FunctionContext::TypeDesc> arg_types;
    for (const Expr* child : expr->children()) {
      arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc(child->type()));
    }
    return codegen->InlineConstFnAttrs(ret_type, arg_types, fn);
  }

  Status CreateFromFile(const string& filename, scoped_ptr<LlvmCodeGen>* codegen) {
    RETURN_IF_ERROR(LlvmCodeGen::CreateFromFile(runtime_state_,
        runtime_state_->obj_pool(), NULL, filename, "test", codegen));
    return (*codegen)->MaterializeModule();
  }

  virtual void SetUp() {
    TQueryOptions query_options;
    query_options.__set_disable_codegen(false);
    query_options.__set_decimal_v2(true);
    test_env_.reset(new TestEnv());
    ASSERT_OK(test_env_->Init());
    ASSERT_OK(test_env_->CreateQueryState(0, &query_options, &runtime_state_));

    FunctionContext::TypeDesc return_type;
    return_type.type = FunctionContext::TYPE_DECIMAL;
    return_type.precision = RET_PRECISION;
    return_type.scale = RET_SCALE;

    FunctionContext::TypeDesc arg0_type;
    arg0_type.type = FunctionContext::TYPE_DECIMAL;
    arg0_type.precision = ARG0_PRECISION;
    arg0_type.scale = ARG0_SCALE;

    FunctionContext::TypeDesc arg1_type;
    arg1_type.type = FunctionContext::TYPE_BOOLEAN;

    FunctionContext::TypeDesc arg2_type;
    arg2_type.type = FunctionContext::TYPE_STRING;

    vector<FunctionContext::TypeDesc> arg_types;
    arg_types.push_back(arg0_type);
    arg_types.push_back(arg1_type);
    arg_types.push_back(arg2_type);

    fn_ctx_ = UdfTestHarness::CreateTestContext(return_type, arg_types, runtime_state_);

    // Initialize fn_ctx_ with constants
    memset(&fn_type_attr_, -1, sizeof(FnAttr));
    fn_ctx_->SetFunctionState(FunctionContext::THREAD_LOCAL, &fn_type_attr_);
  }

  virtual void TearDown() {
    fn_ctx_->impl()->Close();
    delete fn_ctx_;
    runtime_state_ = NULL;
    test_env_.reset();
  }

  void CheckFnAttr() {
    EXPECT_EQ(fn_type_attr_.return_type_size, 8);
    EXPECT_EQ(fn_type_attr_.arg0_type_size, 8);
    EXPECT_EQ(fn_type_attr_.arg1_type_size, ARG1_LEN);
    EXPECT_EQ(fn_type_attr_.arg2_type_size, 0); // varlen
  }

  static bool VerifyFunction(LlvmCodeGen* codegen, llvm::Function* fn) {
    return codegen->VerifyFunction(fn);
  }

  static void ResetVerification(LlvmCodeGen* codegen) {
    codegen->ResetVerification();
  }
};

// Returns a childless BOOL_LITERAL expr node of type BOOLEAN holding the value true.
TExprNode CreateBooleanLiteral() {
  TBoolLiteral literal;
  literal.__set_value(true);

  TScalarType bool_type;
  bool_type.type = TPrimitiveType::BOOLEAN;

  TTypeNode scalar_node;
  scalar_node.type = TTypeNodeType::SCALAR;
  scalar_node.__set_scalar_type(bool_type);

  TColumnType column_type;
  column_type.__set_types(vector<TTypeNode>{scalar_node});

  TExprNode node;
  node.node_type = TExprNodeType::BOOL_LITERAL;
  node.type = column_type;
  node.num_children = 0;
  node.__set_bool_literal(literal);
  return node;
}

// Returns a childless DECIMAL_LITERAL expr node of type DECIMAL(precision, scale). The
// literal holds the fixed value "\1"; this file relies only on the node's type.
TExprNode CreateDecimalLiteral(int precision, int scale) {
  TDecimalLiteral literal;
  literal.value = "\1";

  TScalarType decimal_type;
  decimal_type.type = TPrimitiveType::DECIMAL;
  decimal_type.__set_precision(precision);
  decimal_type.__set_scale(scale);

  TTypeNode scalar_node;
  scalar_node.type = TTypeNodeType::SCALAR;
  scalar_node.__set_scalar_type(decimal_type);

  TColumnType column_type;
  column_type.__set_types(vector<TTypeNode>{scalar_node});

  TExprNode node;
  node.node_type = TExprNodeType::DECIMAL_LITERAL;
  node.type = column_type;
  node.num_children = 0;
  node.__set_decimal_literal(literal);
  return node;
}

// Returns a childless STRING_LITERAL expr node whose value is "\1". If 'len' > 0 the
// node is typed CHAR(len); otherwise, including the default, it is typed STRING.
TExprNode CreateStringLiteral(int len = -1) {
  TScalarType scalar_type;
  // Positive lengths denote a fixed-length CHAR, as documented above.
  scalar_type.type = len > 0 ? TPrimitiveType::CHAR : TPrimitiveType::STRING;
  if (len > 0) scalar_type.__set_len(len);

  TTypeNode type;
  type.type = TTypeNodeType::SCALAR;
  type.__set_scalar_type(scalar_type);

  TColumnType col_type;
  col_type.__set_types(vector<TTypeNode>(1, type));

  TStringLiteral string_literal;
  string_literal.value = "\1";

  TExprNode expr;
  expr.node_type = TExprNodeType::STRING_LITERAL;
  expr.type = col_type;
  expr.num_children = 0;
  expr.__set_string_literal(string_literal);
  return expr;
}

// Creates the FUNCTION_CALL node for TestGetFnAttrs(), which is defined at the top of
// this file and resolved from its IR module via TEST_GET_FN_ATTR_SYMBOL. The call
// returns DECIMAL(precision, scale). The argument types are taken from 'children', but
// the children are not embedded in the returned node: the caller must place them
// directly after it in the flattened TExpr node list.
TExprNode CreateFunctionCall(
    const vector<TExprNode>& children, int precision, int scale) {
  TScalarType scalar_type;
  scalar_type.type = TPrimitiveType::DECIMAL;
  scalar_type.__set_precision(precision);
  scalar_type.__set_scale(scale);

  TTypeNode type;
  type.type = TTypeNodeType::SCALAR;
  type.__set_scalar_type(scalar_type);

  TColumnType col_type;
  col_type.__set_types(vector<TTypeNode>(1, type));

  TFunctionName fn_name;
  fn_name.function_name = "test_get_type_attr";

  TScalarFunction scalar_fn;
  scalar_fn.symbol = TEST_GET_FN_ATTR_SYMBOL;

  TFunction fn;
  fn.name = fn_name;
  fn.binary_type = TFunctionBinaryType::IR;
  for (const TExprNode& child : children) {
    fn.arg_types.push_back(child.type);
  }
  fn.ret_type = col_type;
  fn.has_var_args = false;
  fn.__set_scalar_fn(scalar_fn);

  TExprNode expr;
  expr.node_type = TExprNodeType::FUNCTION_CALL;
  expr.type = col_type;
  expr.num_children = static_cast<int>(children.size());
  expr.__set_fn(fn);
  return expr;
}

TEST_F(ExprCodegenTest, TestGetConstFnAttrsInterpreted) {
  // Invoke the natively compiled TestGetFnAttrs() directly, without codegen. The input
  // 1000025 is 10000.25 as DECIMAL(ARG0_PRECISION=10, ARG0_SCALE=2). The expected
  // output 100003 is 10000.3 as DECIMAL(RET_PRECISION=10, RET_SCALE=1).
  const int64_t unscaled_input = 1000025;
  DecimalVal input(unscaled_input);
  BooleanVal unused_bool;
  StringVal unused_str;
  const DecimalVal cast_result =
      TestGetFnAttrs(fn_ctx_, input, unused_bool, unused_str);

  // Sanity-check the cast itself, then the recorded attributes.
  EXPECT_FALSE(cast_result.is_null);
  EXPECT_EQ(100003, cast_result.val8);
  CheckFnAttr();
}

TEST_F(ExprCodegenTest, TestInlineConstFnAttrs) {
  // Build the thrift tree for
  // DECIMAL(RET_*) TestGetFnAttrs(DECIMAL(ARG0_*), BOOLEAN, STRING).
  // Nodes are listed parent first, followed by its children.
  TExprNode arg0 = CreateDecimalLiteral(ARG0_PRECISION, ARG0_SCALE);
  TExprNode arg1 = CreateBooleanLiteral();
  TExprNode arg2 = CreateStringLiteral();

  vector<TExprNode> exprs;
  exprs.push_back(arg0);
  exprs.push_back(arg1);
  exprs.push_back(arg2);

  TExprNode fn_call = CreateFunctionCall(exprs, RET_PRECISION, RET_SCALE);
  exprs.insert(exprs.begin(), fn_call);

  TExpr texpr;
  texpr.__set_nodes(exprs);

  // Create Expr
  ScalarExpr* expr;
  ASSERT_OK(ScalarExpr::Create(texpr, RowDescriptor(), runtime_state_, &expr));

  // Load the IR produced by cross-compiling this file and look up TestGetFnAttrs().
  // Streaming a null getenv() result would be undefined behavior, so check it first.
  const char* impala_home = getenv("IMPALA_HOME");
  ASSERT_TRUE(impala_home != nullptr) << "IMPALA_HOME must be set";
  stringstream test_udf_file;
  test_udf_file << impala_home << "/be/build/latest/exprs/expr-codegen-test.ll";
  scoped_ptr<LlvmCodeGen> codegen;
  ASSERT_OK(CreateFromFile(test_udf_file.str(), &codegen));
  llvm::Function* fn = codegen->GetFunction(TEST_GET_FN_ATTR_SYMBOL, false);
  ASSERT_TRUE(fn != NULL);

  // Function verification should fail because the GetConstFnAttr() calls have not been
  // inlined yet.
  bool verification_succeeded = VerifyFunction(codegen.get(), fn);
  EXPECT_FALSE(verification_succeeded);

  // Call InlineConstFnAttrs() and rerun verification
  int replaced = InlineConstFnAttrs(expr, codegen.get(), fn);
  EXPECT_EQ(replaced, 19);
  ResetVerification(codegen.get());
  verification_succeeded = VerifyFunction(codegen.get(), fn);
  EXPECT_TRUE(verification_succeeded) << CodeGenUtil::Print(fn);

  // Compile the module. A failure here must stop the test: otherwise 'fn_ptr' would be
  // called without having been set.
  fn = codegen->FinalizeFunction(fn);
  ASSERT_TRUE(fn != NULL);
  void* fn_ptr = nullptr;
  codegen->AddFunctionToJit(fn, &fn_ptr);
  ASSERT_TRUE(codegen->FinalizeModule().ok()) << LlvmCodeGen::Print(fn);
  ASSERT_TRUE(fn_ptr != nullptr);

  // Call the JIT-compiled fn and check the results. The input 1000025 is 10000.25 as
  // Decimal(10,2); the expected output 100003 is 10000.3 as Decimal(10,1). The precision
  // and scale of the arguments and return type are defined above (ARG0_*, RET_*).
  int64_t v = 1000025;
  DecimalVal arg0_val(v);
  BooleanVal arg1_val;
  StringVal arg2_val;
  // The pointer type must match TestGetFnAttrs()'s full signature. Calling through a
  // function type with fewer parameters is undefined behavior.
  using TestGetFnAttrsFn =
      DecimalVal (*)(FunctionContext*, const DecimalVal&, BooleanVal&, StringVal&);
  DecimalVal result = reinterpret_cast<TestGetFnAttrsFn>(fn_ptr)(
      fn_ctx_, arg0_val, arg1_val, arg2_val);
  // sanity check result
  EXPECT_EQ(result.is_null, false);
  EXPECT_EQ(result.val8, 100003);
  CheckFnAttr();
  codegen->Close();
}

}

using namespace impala;

// Test driver. It initializes gtest (which consumes its own command-line flags), the
// common runtime, InitFeSupport() and LLVM, in that order, and then runs every
// registered test. LLVM must be initialized before TestInlineConstFnAttrs loads and
// JIT-compiles the IR module. ABORT_IF_ERROR presumably aborts the process if LLVM
// initialization fails.
int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  InitCommonRuntime(argc, argv, true, TestInfo::BE_TEST);
  InitFeSupport();
  ABORT_IF_ERROR(LlvmCodeGen::InitializeLlvm());

  return RUN_ALL_TESTS();
}

#endif
