blob: d589c75eca6dbadfd0c65e40a01d77f2b449e527 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <cstdint>
#include <cstring>
#include <iostream>
#include "common/logging.h"
#include "udf/uda-test-harness.h"
#include "testutil/gtest-util.h"
#include "testutil/test-udas.h"
#include "common/names.h"
using std::min;
using namespace impala;
using namespace impala_udf;
//-------------------------------- Count ------------------------------------
// Example of implementing Count(int_col).
// The input type is: int
// The intermediate type is bigint
// the return type is bigint
// Start the running count at zero, as a non-NULL value.
void CountInit(FunctionContext* context, BigIntVal* val) {
  val->val = 0;
  val->is_null = false;
}
// Adds one to the running count for every non-NULL input; NULL inputs are ignored.
// 'val' is the same pointer that was passed to CountInit.
void CountUpdate(FunctionContext* context, const IntVal& input, BigIntVal* val) {
  if (!input.is_null) {
    ++val->val;
  }
}
// Combines two partial counts by adding the source count into the destination.
void CountMerge(FunctionContext* context, const BigIntVal& src, BigIntVal* dst) {
  dst->val = dst->val + src.val;
}
// The intermediate already holds the final count, so it is returned unchanged (by copy).
BigIntVal CountFinalize(FunctionContext* context, const BigIntVal& val) {
  BigIntVal count(val);
  return count;
}
//-------------------------------- Count(...) ------------------------------------
// Example of implementing Count(...)
// The input type is: multiple ints
// The intermediate type is bigint
// the return type is bigint
// Adds the number of non-NULL arguments (0, 1 or 2) to the running count.
void Count2Update(FunctionContext* context, const IntVal& input1, const IntVal& input2,
    BigIntVal* val) {
  int non_null_args = 0;
  if (!input1.is_null) ++non_null_args;
  if (!input2.is_null) ++non_null_args;
  val->val += non_null_args;
}
// Adds the number of non-NULL arguments (0 to 3) to the running count.
void Count3Update(FunctionContext* context, const IntVal& input1, const IntVal& input2,
    const IntVal& input3, BigIntVal* val) {
  int non_null_args = 0;
  if (!input1.is_null) ++non_null_args;
  if (!input2.is_null) ++non_null_args;
  if (!input3.is_null) ++non_null_args;
  val->val += non_null_args;
}
// Adds the number of non-NULL arguments (0 to 4) to the running count.
void Count4Update(FunctionContext* context, const IntVal& input1, const IntVal& input2,
    const IntVal& input3, const IntVal& input4, BigIntVal* val) {
  int non_null_args = 0;
  if (!input1.is_null) ++non_null_args;
  if (!input2.is_null) ++non_null_args;
  if (!input3.is_null) ++non_null_args;
  if (!input4.is_null) ++non_null_args;
  val->val += non_null_args;
}
//-------------------------------- Min(String) ------------------------------------
// Example of implementing MIN for strings.
// The input type is: STRING
// The intermediate type is BufferVal
// the return type is STRING
// This is a little more sophisticated since the result buffers are reused (it grows
// to the longest result string).
// Scratch state for the MIN(string) UDA, laid out directly in the BufferVal
// intermediate (see MinInit). value == NULL means "no non-NULL input seen yet".
struct MinState {
  // Current minimum string bytes (not NUL-terminated); NULL until the first Set().
  uint8_t* value;
  // Length in bytes of the current minimum.
  int len;
  // Capacity in bytes of the buffer pointed to by 'value'.
  int buffer_len;

  // Replaces the current minimum with a copy of 'val', growing the buffer when needed.
  void Set(FunctionContext* context, const StringVal& val) {
    // Allocate when there is no buffer yet, not only when it is too small. Request at
    // least one byte so that an empty-string minimum still yields a non-NULL 'value';
    // otherwise the empty string would be indistinguishable from "no value seen".
    if (value == NULL || buffer_len < val.len) {
      const int new_buffer_len = val.len > 0 ? val.len : 1;
      context->Free(value);
      value = context->Allocate(new_buffer_len);
      buffer_len = new_buffer_len;
    }
    // Skip the copy for empty strings: val.ptr may be NULL, and passing NULL to memcpy
    // is undefined behavior even with a zero length.
    if (val.len > 0) memcpy(value, val.ptr, val.len);
    len = val.len;
  }
};
// Initialize the MinState scratch space stored in the BufferVal intermediate.
// value == NULL marks "no non-NULL input seen yet". Every field is set so that no
// member of the state is left indeterminate.
void MinInit(FunctionContext* context, BufferVal* val) {
  MinState* state = reinterpret_cast<MinState*>(*val);
  state->value = NULL;
  state->len = 0;
  state->buffer_len = 0;
}
// Folds one input string into the running minimum held in MinState. NULL inputs are
// ignored. The first non-NULL input always becomes the minimum. Afterwards, an input
// replaces the minimum when it compares lower on the common prefix, or when it is a
// strict prefix of the current minimum.
void MinUpdate(FunctionContext* context, const StringVal& input, BufferVal* val) {
  if (input.is_null) return;
  MinState* state = reinterpret_cast<MinState*>(*val);
  bool take_input = (state->value == NULL);
  if (!take_input) {
    const int common_len = ::min(input.len, state->len);
    const int order = memcmp(input.ptr, state->value, common_len);
    take_input = order < 0 || (order == 0 && input.len < state->len);
  }
  if (take_input) state->Set(context, input);
}
// Serialize the state into the min string. The state is modified in place and the
// same intermediate is returned.
BufferVal MinSerialize(FunctionContext* context, const BufferVal& intermediate) {
  MinState* state = reinterpret_cast<MinState*>(intermediate);
  // No value has been stored yet (value == NULL), so there is no buffer to move.
  if (state->value == NULL) return intermediate;
  // Hack to persist the intermediate state's value without leaking.
  // TODO: revisit BufferVal and design a better way to do this
  // The current minimum is copied into a buffer obtained via StringVal(context, len)
  // and the previous buffer is released with context->Free().
  // Request at least one byte so that an empty-string minimum still leaves 'value'
  // non-NULL (NULL means "no value seen"); a zero-length request may yield NULL.
  const int copy_len = state->len > 0 ? state->len : 1;
  StringVal copy_buffer(context, copy_len);
  if (state->len > 0) memcpy(copy_buffer.ptr, state->value, state->len);
  context->Free(state->value);
  state->value = copy_buffer.ptr;
  // The replacement buffer holds only 'copy_len' bytes. Record that; otherwise a later
  // MinState::Set() would trust the old, possibly larger, buffer_len and write past
  // the end of this buffer.
  state->buffer_len = copy_len;
  return intermediate;
}
// Merges a source state into 'dst'. The serialized state holds the raw minimum string,
// so merging is just MinUpdate with the source's minimum. An empty source
// (value == NULL) contributes nothing.
void MinMerge(FunctionContext* context, const BufferVal& src, BufferVal* dst) {
  const MinState* other = reinterpret_cast<const MinState*>(src);
  if (other->value != NULL) {
    MinUpdate(context, StringVal(other->value, other->len), dst);
  }
}
// Produces the final result: a copy of the current minimum string, or NULL when no
// non-NULL input was seen. The state's buffer is freed once it has been copied.
StringVal MinFinalize(FunctionContext* context, const BufferVal& val) {
  const MinState* state = reinterpret_cast<const MinState*>(val);
  StringVal result = StringVal::null();
  if (state->value != NULL) {
    result = StringVal::CopyFrom(context, state->value, state->len);
    context->Free(state->value);
  }
  return result;
}
//----------------------------- Bits after Xor ------------------------------------
// Example of a UDA that xors all the input bits and then returns the number of
// resulting bits that are set. This illustrates where the result and intermediate
// are the same type, but a transformation is still needed in Finalize().
// The input type is: double
// The intermediate type is bigint
// the return type is bigint
// Starts the bit accumulator with all bits clear, as a non-NULL value.
void XorInit(FunctionContext* context, BigIntVal* val) {
  val->val = 0;
  val->is_null = false;
}
// XORs the raw bit pattern of one double input into the accumulator.
// 'val' is the accumulator initialized by XorInit. A NULL input pointer is skipped.
void XorUpdate(FunctionContext* context, const double* input, BigIntVal* val) {
  if (input == NULL) return;
  // Copy the double's bytes into an integer. Dereferencing a
  // reinterpret_cast<const int64_t*> of a double violates strict aliasing; memcpy
  // does the same reinterpretation without undefined behavior.
  int64_t bits;
  memcpy(&bits, input, sizeof(bits));
  // This UDA is defined to xor the input bits (see the section header), not OR them.
  val->val ^= bits;
}
// Combines two partial XOR accumulators. XOR is associative and commutative, so
// XOR-ing the partial results equals XOR-ing every input in a single pass. (The
// previous OR did not match the xor contract of this UDA.)
void XorMerge(FunctionContext* context, const BigIntVal& src, BigIntVal* dst) {
  dst->val ^= src.val;
}
// Returns the number of bits set in the accumulated XOR, which is this UDA's result.
BigIntVal XorFinalize(FunctionContext* context, const BigIntVal& val) {
  // Popcount with Kernighan's method: each iteration clears the lowest set bit. Work
  // on an unsigned copy so that handling the sign bit is well-defined.
  uint64_t remaining = static_cast<uint64_t>(val.val);
  int64_t set_bits = 0;
  while (remaining != 0) {
    remaining &= remaining - 1;
    ++set_bits;
  }
  return BigIntVal(set_bits);
}
//--------------------------- HLL(Distinct Estimate) ---------------------------------
// Example of implementing distinct estimate. As an example, we will compress the
// intermediate buffer.
// Note: this is not the actual algorithm but a sketch of how it would be implemented
// with the UDA interface.
// The input type is: bigint
// The intermediate type is string (fixed at 256 bytes)
// the return type is bigint
// Clears the fixed-size intermediate. The intermediate size is known up front, so
// 'val' is expected to arrive with a 256-byte buffer; the assert checks this when
// asserts are enabled.
void DistinctEstimateInit(FunctionContext* context, StringVal* val) {
  const int kIntermediateLen = 256;
  assert(val->len == kIntermediateLen);
  memset(val->ptr, 0, kIntermediateLen);
}
// Sketch of the update step. Byte 'fn' of the intermediate is meant to receive the
// hash of *input under the fn-th hash function (e.g. Hash(*input, fn)). No hash is
// computed yet, so every byte is written as 0. A NULL input pointer is skipped.
void DistinctEstimatUpdate(FunctionContext* context,
    const int64_t* input, StringVal* val) {
  if (input == NULL) return;
  for (int fn = 0; fn < 256; ++fn) {
    const int hash = 0;  // placeholder for Hash(*input, fn)
    val->ptr[fn] = hash;
  }
}
// Sketch of serialization. The intermediate would be compressed here, e.g.
// SnappyCompress(intermediate.ptr, intermediate.len). Compression is not wired up, so
// a StringVal with a NULL pointer and zero length is returned.
StringVal DistinctEstimatSerialize(FunctionContext* context,
    const StringVal& intermediate) {
  uint8_t* compressed = NULL;
  const int compressed_len = 0;
  return StringVal(compressed, compressed_len);
}
// Sketch of merging: decompresses 'src' and folds its 256 bytes into 'dst' with XOR.
void DistinctEstimateMerge(FunctionContext* context, const StringVal& src, StringVal* dst) {
  uint8_t* src_uncompressed = NULL; // SnappyUncompress(src.ptr, src.len);
  // Decompression is only sketched above, so 'src_uncompressed' is NULL here. Return
  // instead of dereferencing a NULL pointer in the loop below.
  if (src_uncompressed == NULL) return;
  for (int i = 0; i < 256; ++i) {
    dst->ptr[i] ^= src_uncompressed[i];
  }
}
// Sketch of finalization. The estimate is meant to be derived from the bits set in
// the intermediate bytes of 'val' (a popcount). That is not implemented, so the
// result is always 0.
BigIntVal DistinctEstimateFinalize(FunctionContext* context, const StringVal& val) {
  const int64_t estimate = 0;
  return BigIntVal(estimate);
}
// COUNT over 1000 default-constructed IntVals. The test expects all of them to be
// counted, i.e. none is NULL. It also checks that a wrong expected result is rejected.
TEST(CountTest, Basic) {
  vector<IntVal> no_nulls(1000);
  UdaTestHarness<BigIntVal, BigIntVal, IntVal> test(
      CountInit, CountUpdate, CountMerge, NULL, CountFinalize);
  const BigIntVal expected(no_nulls.size());
  EXPECT_TRUE(test.Execute(no_nulls, expected)) << test.GetErrorMsg();
  EXPECT_FALSE(test.Execute(no_nulls, BigIntVal(100))) << test.GetErrorMsg();
}
// COUNT(...) with 2, 3 and 4 arguments over 'num' rows with no NULLs. Each row
// contributes one per argument. Every harness must accept the correct total and
// reject a wrong one. Each assertion streams the harness's error message, as the other
// tests in this file do, so that failures are diagnosable.
TEST(CountMultiArgTest, Basic) {
  const int num = 1000;
  vector<IntVal> no_nulls;
  no_nulls.resize(num);

  UdaTestHarness2<BigIntVal, BigIntVal, IntVal, IntVal> test2(
      CountInit, Count2Update, CountMerge, NULL, CountFinalize);
  EXPECT_TRUE(test2.Execute(no_nulls, no_nulls, BigIntVal(2 * num)))
      << test2.GetErrorMsg();
  EXPECT_FALSE(test2.Execute(no_nulls, no_nulls, BigIntVal(100)))
      << test2.GetErrorMsg();

  UdaTestHarness3<BigIntVal, BigIntVal, IntVal, IntVal, IntVal> test3(
      CountInit, Count3Update, CountMerge, NULL, CountFinalize);
  EXPECT_TRUE(test3.Execute(no_nulls, no_nulls, no_nulls, BigIntVal(3 * num)))
      << test3.GetErrorMsg();
  EXPECT_FALSE(test3.Execute(no_nulls, no_nulls, no_nulls, BigIntVal(100)))
      << test3.GetErrorMsg();

  UdaTestHarness4<BigIntVal, BigIntVal, IntVal, IntVal, IntVal, IntVal> test4(
      CountInit, Count4Update, CountMerge, NULL, CountFinalize);
  EXPECT_TRUE(test4.Execute(no_nulls, no_nulls, no_nulls, no_nulls, BigIntVal(4 * num)))
      << test4.GetErrorMsg();
  EXPECT_FALSE(test4.Execute(no_nulls, no_nulls, no_nulls, no_nulls, BigIntVal(100)))
      << test4.GetErrorMsg();
}
// Result comparator that accepts results differing by at most 1. Two NULLs are equal;
// NULL versus non-NULL is not.
bool FuzzyCompare(const BigIntVal& r1, const BigIntVal& r2) {
  if (r1.is_null && r2.is_null) return true;
  if (r1.is_null || r2.is_null) return false;
  // Compute |r1 - r2| in unsigned arithmetic. The signed subtraction overflows
  // (undefined behavior) for distant values such as INT64_MAX and -1. For a >= b,
  // uint64_t(a) - uint64_t(b) is exactly the distance between them.
  const uint64_t distance = r1.val >= r2.val
      ? static_cast<uint64_t>(r1.val) - static_cast<uint64_t>(r2.val)
      : static_cast<uint64_t>(r2.val) - static_cast<uint64_t>(r1.val);
  return distance <= 1;
}
// Checks SetResultComparator. Before it is called, only the exact count (1000) is
// accepted. After FuzzyCompare is installed, an off-by-one result is also accepted,
// but an off-by-two result is not.
TEST(CountTest, FuzzyEquals) {
  vector<IntVal> no_nulls(1000);
  UdaTestHarness<BigIntVal, BigIntVal, IntVal> test(
      CountInit, CountUpdate, CountMerge, NULL, CountFinalize);

  EXPECT_TRUE(test.Execute(no_nulls, BigIntVal(1000))) << test.GetErrorMsg();
  EXPECT_FALSE(test.Execute(no_nulls, BigIntVal(999))) << test.GetErrorMsg();

  test.SetResultComparator(FuzzyCompare);
  EXPECT_TRUE(test.Execute(no_nulls, BigIntVal(1000))) << test.GetErrorMsg();
  EXPECT_TRUE(test.Execute(no_nulls, BigIntVal(999))) << test.GetErrorMsg();
  EXPECT_FALSE(test.Execute(no_nulls, BigIntVal(998))) << test.GetErrorMsg();
}
// MIN(string) over a growing input set. The BufferVal intermediate is raw scratch
// space reinterpreted as MinState, so the harness must be told its size.
TEST(MinTest, Basic) {
  UdaTestHarness<StringVal, BufferVal, StringVal> test(
      MinInit, MinUpdate, MinMerge, MinSerialize, MinFinalize);
  test.SetIntermediateSize(sizeof(MinState));

  vector<StringVal> inputs;
  // A single value is its own minimum.
  inputs.push_back(StringVal("BBB"));
  EXPECT_TRUE(test.Execute(inputs, StringVal("BBB"))) << test.GetErrorMsg();
  // A lexicographically smaller value takes over.
  inputs.push_back(StringVal("AA"));
  EXPECT_TRUE(test.Execute(inputs, StringVal("AA"))) << test.GetErrorMsg();
  // A larger value leaves the minimum alone.
  inputs.push_back(StringVal("CCC"));
  EXPECT_TRUE(test.Execute(inputs, StringVal("AA"))) << test.GetErrorMsg();
  // "A" is a prefix of every other "A..." string and therefore the smallest.
  inputs.push_back(StringVal("ABCDEF"));
  inputs.push_back(StringVal("AABCDEF"));
  inputs.push_back(StringVal("A"));
  EXPECT_TRUE(test.Execute(inputs, StringVal("A"))) << test.GetErrorMsg();

  // With only NULL input the result is NULL.
  inputs.clear();
  inputs.push_back(StringVal::null());
  EXPECT_TRUE(test.Execute(inputs, StringVal::null())) << test.GetErrorMsg();
  // NULLs are skipped when non-NULL values are present.
  inputs.push_back(StringVal("ZZZ"));
  EXPECT_TRUE(test.Execute(inputs, StringVal("ZZZ"))) << test.GetErrorMsg();
}
// Runs the MemTest* UDA (declared outside this file) over ten inputs of 10 and expects
// a result of 100. It then runs the same UDA without a Serialize function and expects
// that run to fail. The name 'test_leak' suggests the harness detects unreleased
// memory; confirm against the MemTest* definitions.
TEST(MemTest, Basic) {
  UdaTestHarness<BigIntVal, BigIntVal, BigIntVal> test(
      ::MemTestInit, ::MemTestUpdate, ::MemTestMerge, ::MemTestSerialize,
      ::MemTestFinalize);
  vector<BigIntVal> input;
  for (int i = 0; i < 10; ++i) {
    input.push_back(10);
  }
  EXPECT_TRUE(test.Execute(input, BigIntVal(100))) << test.GetErrorMsg();

  UdaTestHarness<BigIntVal, BigIntVal, BigIntVal> test_leak(
      ::MemTestInit, ::MemTestUpdate, ::MemTestMerge, NULL, ::MemTestFinalize);
  // Report the error message of the harness that actually ran, not of 'test'.
  EXPECT_FALSE(test_leak.Execute(input, BigIntVal(100))) << test_leak.GetErrorMsg();
}
// Entry point for this test binary. The macro is not defined in this file; it
// presumably comes from testutil/gtest-util.h — confirm there.
IMPALA_TEST_MAIN();