blob: e1c5c501bb74d925f89e16f3bf6b8854f4371504 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// The following is cross-compiled to native code and IR, and used in the test below
#include "exprs/decimal-operators.h"
#include "exprs/scalar-expr.h"
#include "udf/udf.h"
#ifdef IR_COMPILE
#include "exprs/decimal-operators-ir.cc"
#endif
using namespace impala;
using namespace impala_udf;
// Type sizes recorded by TestGetFnAttrs(). Each field receives the value returned by
// FunctionContextImpl::GetConstFnAttr() for the matching attribute. The test fixture
// memset()s the whole struct with -1 before each test, so a field that was never
// written does not accidentally match an expected size.
struct FnAttr {
// GetConstFnAttr(RETURN_TYPE_SIZE).
int return_type_size;
// GetConstFnAttr(ARG_TYPE_SIZE, 0).
int arg0_type_size;
// GetConstFnAttr(ARG_TYPE_SIZE, 1).
int arg1_type_size;
// GetConstFnAttr(ARG_TYPE_SIZE, 2).
int arg2_type_size;
};
// Stores the constant type-size attributes of the return type and of the three
// arguments into the FnAttr registered as THREAD_LOCAL function state on 'ctx', then
// returns 'arg0' run through DecimalOperators::CastToDecimalVal(). 'arg1' and 'arg2'
// are never read; they only give the function a boolean and a string parameter.
// The parameter types must stay exactly as they are: the test looks this function up
// in the IR module by its mangled name.
DecimalVal TestGetFnAttrs(
    FunctionContext* ctx, const DecimalVal& arg0, BooleanVal& arg1, StringVal& arg2) {
  auto* impl = ctx->impl();
  FnAttr* attrs =
      reinterpret_cast<FnAttr*>(ctx->GetFunctionState(FunctionContext::THREAD_LOCAL));

  attrs->return_type_size = impl->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SIZE);
  attrs->arg0_type_size = impl->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0);
  attrs->arg1_type_size = impl->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 1);
  attrs->arg2_type_size = impl->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 2);

  // Both this function and the cast below (and its callees) call
  // FunctionContextImpl::GetConstFnAttr().
  return DecimalOperators::CastToDecimalVal(ctx, arg0);
}
// Don't compile the actual test to IR
#ifndef IR_COMPILE
#include "testutil/gtest-util.h"
#include "codegen/codegen-util.h"
#include "codegen/llvm-codegen.h"
#include "common/init.h"
#include "exprs/anyval-util.h"
#include "exprs/scalar-expr.h"
#include "exprs/scalar-expr-evaluator.h"
#include "runtime/exec-env.h"
#include "runtime/mem-tracker.h"
#include "runtime/runtime-state.h"
#include "runtime/test-env.h"
#include "service/fe-support.h"
#include "udf/udf-internal.h"
#include "udf/udf-test-harness.h"
#include "gen-cpp/Exprs_types.h"
#include "common/names.h"
namespace impala {
// Mangled name of
//   TestGetFnAttrs(FunctionContext*, const DecimalVal&, BooleanVal&, StringVal&)
// used both as the scalar function symbol and to look the function up in the IR module.
const char* TEST_GET_FN_ATTR_SYMBOL =
    "_Z14TestGetFnAttrsPN10impala_udf15FunctionContextERKNS_10DecimalValERNS_10BooleanValERNS_9StringValE";

// Precision and scale of the DECIMAL argument (arg0).
constexpr int ARG0_PRECISION = 10;
constexpr int ARG0_SCALE = 2;

// Expected type size of the BOOLEAN argument (arg1).
constexpr int ARG1_LEN = 1;

// Precision and scale of the DECIMAL return type.
constexpr int RET_PRECISION = 10;
constexpr int RET_SCALE = 1;
class ExprCodegenTest : public ::testing::Test {
protected:
scoped_ptr<TestEnv> test_env_;
RuntimeState* runtime_state_;
FunctionContext* fn_ctx_;
FnAttr fn_type_attr_;
int InlineConstFnAttrs(const Expr* expr, LlvmCodeGen* codegen, llvm::Function* fn) {
FunctionContext::TypeDesc ret_type = AnyValUtil::ColumnTypeToTypeDesc(expr->type());
vector<FunctionContext::TypeDesc> arg_types;
for (const Expr* child : expr->children()) {
arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc(child->type()));
}
return codegen->InlineConstFnAttrs(ret_type, arg_types, fn);
}
Status CreateFromFile(const string& filename, scoped_ptr<LlvmCodeGen>* codegen) {
RETURN_IF_ERROR(LlvmCodeGen::CreateFromFile(runtime_state_,
runtime_state_->obj_pool(), NULL, filename, "test", codegen));
return (*codegen)->MaterializeModule();
}
virtual void SetUp() {
TQueryOptions query_options;
query_options.__set_disable_codegen(false);
query_options.__set_decimal_v2(true);
test_env_.reset(new TestEnv());
ASSERT_OK(test_env_->Init());
ASSERT_OK(test_env_->CreateQueryState(0, &query_options, &runtime_state_));
FunctionContext::TypeDesc return_type;
return_type.type = FunctionContext::TYPE_DECIMAL;
return_type.precision = RET_PRECISION;
return_type.scale = RET_SCALE;
FunctionContext::TypeDesc arg0_type;
arg0_type.type = FunctionContext::TYPE_DECIMAL;
arg0_type.precision = ARG0_PRECISION;
arg0_type.scale = ARG0_SCALE;
FunctionContext::TypeDesc arg1_type;
arg1_type.type = FunctionContext::TYPE_BOOLEAN;
FunctionContext::TypeDesc arg2_type;
arg2_type.type = FunctionContext::TYPE_STRING;
vector<FunctionContext::TypeDesc> arg_types;
arg_types.push_back(arg0_type);
arg_types.push_back(arg1_type);
arg_types.push_back(arg2_type);
fn_ctx_ = UdfTestHarness::CreateTestContext(return_type, arg_types, runtime_state_);
// Initialize fn_ctx_ with constants
memset(&fn_type_attr_, -1, sizeof(FnAttr));
fn_ctx_->SetFunctionState(FunctionContext::THREAD_LOCAL, &fn_type_attr_);
}
virtual void TearDown() {
fn_ctx_->impl()->Close();
delete fn_ctx_;
runtime_state_ = NULL;
test_env_.reset();
}
void CheckFnAttr() {
EXPECT_EQ(fn_type_attr_.return_type_size, 8);
EXPECT_EQ(fn_type_attr_.arg0_type_size, 8);
EXPECT_EQ(fn_type_attr_.arg1_type_size, ARG1_LEN);
EXPECT_EQ(fn_type_attr_.arg2_type_size, 0); // varlen
}
static bool VerifyFunction(LlvmCodeGen* codegen, llvm::Function* fn) {
return codegen->VerifyFunction(fn);
}
static void ResetVerification(LlvmCodeGen* codegen) {
codegen->ResetVerification();
}
};
// Builds a childless BOOL_LITERAL expr node of scalar type BOOLEAN holding 'true'.
TExprNode CreateBooleanLiteral() {
  TScalarType bool_scalar;
  bool_scalar.type = TPrimitiveType::BOOLEAN;

  TTypeNode type_node;
  type_node.type = TTypeNodeType::SCALAR;
  type_node.__set_scalar_type(bool_scalar);

  // The column type consists of exactly this one scalar type node.
  vector<TTypeNode> type_nodes;
  type_nodes.push_back(type_node);
  TColumnType column_type;
  column_type.__set_types(type_nodes);

  TBoolLiteral literal;
  literal.__set_value(true);

  TExprNode node;
  node.node_type = TExprNodeType::BOOL_LITERAL;
  node.num_children = 0;
  node.type = column_type;
  node.__set_bool_literal(literal);
  return node;
}
// Builds a childless DECIMAL_LITERAL expr node of type DECIMAL('precision', 'scale').
// The literal's value is the single byte "\1".
TExprNode CreateDecimalLiteral(int precision, int scale) {
  TScalarType decimal_scalar;
  decimal_scalar.type = TPrimitiveType::DECIMAL;
  decimal_scalar.__set_precision(precision);
  decimal_scalar.__set_scale(scale);

  TTypeNode type_node;
  type_node.type = TTypeNodeType::SCALAR;
  type_node.__set_scalar_type(decimal_scalar);

  // The column type consists of exactly this one scalar type node.
  vector<TTypeNode> type_nodes;
  type_nodes.push_back(type_node);
  TColumnType column_type;
  column_type.__set_types(type_nodes);

  TDecimalLiteral literal;
  literal.value = "\1";

  TExprNode node;
  node.node_type = TExprNodeType::DECIMAL_LITERAL;
  node.num_children = 0;
  node.type = column_type;
  node.__set_decimal_literal(literal);
  return node;
}
// Builds a childless STRING_LITERAL expr node whose value is the single byte "\1".
// If 'len' > 0 the node's type is VARCHAR with that length; otherwise (including the
// default of -1) it is STRING and no length is set.
TExprNode CreateStringLiteral(int len = -1) {
  TScalarType string_scalar;
  if (len > 0) {
    string_scalar.type = TPrimitiveType::VARCHAR;
    string_scalar.__set_len(len);
  } else {
    string_scalar.type = TPrimitiveType::STRING;
  }

  TTypeNode type_node;
  type_node.type = TTypeNodeType::SCALAR;
  type_node.__set_scalar_type(string_scalar);

  // The column type consists of exactly this one scalar type node.
  vector<TTypeNode> type_nodes;
  type_nodes.push_back(type_node);
  TColumnType column_type;
  column_type.__set_types(type_nodes);

  TStringLiteral literal;
  literal.value = "\1";

  TExprNode node;
  node.node_type = TExprNodeType::STRING_LITERAL;
  node.num_children = 0;
  node.type = column_type;
  node.__set_string_literal(literal);
  return node;
}
// Builds the FUNCTION_CALL expr node for a call to TestGetFnAttrs() (defined at the top
// of this file), resolved as an IR function through TEST_GET_FN_ATTR_SYMBOL.
//
// 'children' supplies the argument nodes: their types become the function's argument
// types and their count becomes num_children. The children themselves are not embedded
// in the returned node; the caller assembles the node list.
// 'precision' and 'scale' describe the DECIMAL return type.
//
// 'children' is taken by const reference because it is only read; existing callers
// that pass a vector keep working unchanged.
TExprNode CreateFunctionCall(
    const vector<TExprNode>& children, int precision, int scale) {
  TScalarType scalar_type;
  scalar_type.type = TPrimitiveType::DECIMAL;
  scalar_type.__set_precision(precision);
  scalar_type.__set_scale(scale);

  TTypeNode type;
  type.type = TTypeNodeType::SCALAR;
  type.__set_scalar_type(scalar_type);

  TColumnType col_type;
  col_type.__set_types(vector<TTypeNode>(1, type));

  TFunctionName fn_name;
  fn_name.function_name = "test_get_type_attr";

  TScalarFunction scalar_fn;
  scalar_fn.symbol = TEST_GET_FN_ATTR_SYMBOL;

  TFunction fn;
  fn.name = fn_name;
  fn.binary_type = TFunctionBinaryType::IR;
  for (const TExprNode& child : children) {
    fn.arg_types.push_back(child.type);
  }
  fn.ret_type = col_type;
  fn.has_var_args = false;
  fn.__set_scalar_fn(scalar_fn);

  TExprNode expr;
  expr.node_type = TExprNodeType::FUNCTION_CALL;
  expr.type = col_type;
  expr.num_children = children.size();
  expr.__set_fn(fn);
  return expr;
}
// Calls the natively compiled TestGetFnAttrs() directly and checks both the cast result
// and the type sizes it recorded in the fixture's FnAttr.
TEST_F(ExprCodegenTest, TestGetConstFnAttrsInterpreted) {
  // The input is a Decimal(10,2) with unscaled value 1000025 (i.e. 10000.25); the
  // return type is Decimal(10,1), so the expected result is 10000.3 (unscaled 100003).
  // The precisions and scales come from the ARG0_* and RET_* constants above.
  const int64_t unscaled_input = 1000025;
  DecimalVal decimal_arg(unscaled_input);
  BooleanVal bool_arg;
  StringVal string_arg;

  const DecimalVal output = TestGetFnAttrs(fn_ctx_, decimal_arg, bool_arg, string_arg);

  // Sanity check the returned decimal before looking at the recorded attributes.
  EXPECT_FALSE(output.is_null);
  EXPECT_EQ(output.val8, 100003);
  CheckFnAttr();
}
// Loads the cross-compiled IR of TestGetFnAttrs(), checks that it only verifies once
// the GetConstFnAttr() calls have been replaced by InlineConstFnAttrs(), JIT-compiles
// it and checks that the compiled function behaves like the native one.
TEST_F(ExprCodegenTest, TestInlineConstFnAttrs) {
  // Setup thrift descriptors. The node list is: the function call first, followed by
  // its three argument literals.
  TExprNode arg0 = CreateDecimalLiteral(ARG0_PRECISION, ARG0_SCALE);
  TExprNode arg1 = CreateBooleanLiteral();
  TExprNode arg2 = CreateStringLiteral();
  vector<TExprNode> exprs;
  exprs.push_back(arg0);
  exprs.push_back(arg1);
  exprs.push_back(arg2);
  TExprNode fn_call = CreateFunctionCall(exprs, RET_PRECISION, RET_SCALE);
  exprs.insert(exprs.begin(), fn_call);
  TExpr texpr;
  texpr.__set_nodes(exprs);

  // Create Expr
  // NOTE(review): 'tracker' is not referenced below; kept in case its construction is
  // relied upon -- confirm and remove if not.
  MemTracker tracker;
  ScalarExpr* expr;
  ASSERT_OK(ScalarExpr::Create(texpr, RowDescriptor(), runtime_state_, &expr));

  // Get TestGetFnAttrs() IR function. getenv() returns NULL when the variable is unset
  // and streaming a null char* is not allowed, so fail with a clear message instead.
  const char* impala_home = getenv("IMPALA_HOME");
  ASSERT_TRUE(impala_home != NULL) << "IMPALA_HOME must be set to locate the IR module";
  stringstream test_udf_file;
  test_udf_file << impala_home << "/be/build/latest/exprs/expr-codegen-test.ll";
  scoped_ptr<LlvmCodeGen> codegen;
  ASSERT_OK(CreateFromFile(test_udf_file.str(), &codegen));
  llvm::Function* fn = codegen->GetFunction(TEST_GET_FN_ATTR_SYMBOL, false);
  ASSERT_TRUE(fn != NULL);

  // Function verification should fail because we haven't inlined the
  // GetConstFnAttr() calls yet
  bool verification_succeeded = VerifyFunction(codegen.get(), fn);
  EXPECT_FALSE(verification_succeeded);

  // Call InlineConstFnAttrs() and rerun verification
  int replaced = InlineConstFnAttrs(expr, codegen.get(), fn);
  EXPECT_EQ(replaced, 19);
  ResetVerification(codegen.get());
  verification_succeeded = VerifyFunction(codegen.get(), fn);
  EXPECT_TRUE(verification_succeeded) << CodeGenUtil::Print(fn);

  // Compile module
  fn = codegen->FinalizeFunction(fn);
  ASSERT_TRUE(fn != NULL);
  void* fn_ptr = NULL;
  codegen->AddFunctionToJit(fn, &fn_ptr);
  EXPECT_TRUE(codegen->FinalizeModule().ok()) << LlvmCodeGen::Print(fn);
  // Never jump through a pointer that was not filled in (e.g. if finalization failed).
  ASSERT_TRUE(fn_ptr != NULL) << "no jitted function pointer for TestGetFnAttrs()";

  // Call fn and check results. The input is of type Decimal(10,2) (i.e. 10000.25) and
  // the output type is Decimal(10,1) (i.e. 10000.3). The precision and scale of arguments
  // and return types are encoded above (ARG0_*, RET_*);
  int64_t v = 1000025;
  DecimalVal arg0_val(v);
  BooleanVal arg1_val;
  StringVal arg2_val;
  // The pointer type must match the real signature of TestGetFnAttrs(): calling a
  // function through a pointer to a different function type is undefined behaviour,
  // even though the boolean and string arguments are never read.
  typedef DecimalVal (*TestGetFnAttrsFn)(
      FunctionContext*, const DecimalVal&, BooleanVal&, StringVal&);
  DecimalVal result = reinterpret_cast<TestGetFnAttrsFn>(fn_ptr)(
      fn_ctx_, arg0_val, arg1_val, arg2_val);

  // sanity check result
  EXPECT_EQ(result.is_null, false);
  EXPECT_EQ(result.val8, 100003);
  CheckFnAttr();
  codegen->Close();
}
}
using namespace impala;
// Test entry point. gtest parses its flags first, then the common runtime and frontend
// support are initialised for a backend test, and LLVM is initialised (aborting on
// error) before any test runs. The process exit code is the result of RUN_ALL_TESTS().
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  InitCommonRuntime(argc, argv, true, TestInfo::BE_TEST);
  InitFeSupport();
  ABORT_IF_ERROR(LlvmCodeGen::InitializeLlvm());
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}
#endif