| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| **/ |
| |
| #include <cstddef> |
| #include <cstdio> |
| #include <cstring> |
| #include <memory> |
| #include <utility> |
| #include <vector> |
| |
| #include "catalog/CatalogAttribute.hpp" |
| #include "catalog/CatalogRelation.hpp" |
| #include "catalog/CatalogRelationSchema.hpp" |
| #include "expressions/predicate/ComparisonPredicate.hpp" |
| #include "expressions/predicate/ConjunctionPredicate.hpp" |
| #include "expressions/predicate/DisjunctionPredicate.hpp" |
| #include "expressions/predicate/Predicate.hpp" |
| #include "expressions/scalar/Scalar.hpp" |
| #include "expressions/scalar/ScalarAttribute.hpp" |
| #include "expressions/scalar/ScalarBinaryExpression.hpp" |
| #include "expressions/scalar/ScalarCaseExpression.hpp" |
| #include "expressions/scalar/ScalarLiteral.hpp" |
| #include "expressions/scalar/ScalarUnaryExpression.hpp" |
| #include "storage/StorageBlockInfo.hpp" |
| #include "storage/TupleIdSequence.hpp" |
| #include "storage/ValueAccessor.hpp" |
| #include "types/TypeFactory.hpp" |
| #include "types/TypeID.hpp" |
| #include "types/TypedValue.hpp" |
| #include "types/containers/ColumnVector.hpp" |
| #include "types/containers/ColumnVectorsValueAccessor.hpp" |
| #include "types/operations/binary_operations/BinaryOperationFactory.hpp" |
| #include "types/operations/binary_operations/BinaryOperationID.hpp" |
| #include "types/operations/comparisons/ComparisonFactory.hpp" |
| #include "types/operations/comparisons/ComparisonID.hpp" |
| #include "types/operations/unary_operations/UnaryOperationFactory.hpp" |
| #include "types/operations/unary_operations/UnaryOperationID.hpp" |
| |
| #include "gtest/gtest.h" |
| |
| // NetBSD's libc has snprintf(), but it doesn't show up in the std namespace |
| // for C++. |
| #ifndef __NetBSD__ |
| using std::snprintf; |
| #endif |
| |
| namespace quickstep { |
| |
| class Type; |
| |
| class ScalarCaseExpressionTest : public ::testing::Test { |
| protected: |
| static constexpr std::size_t kNumSampleTuples = 1000; |
| |
| virtual void SetUp() { |
| sample_relation_.reset(new CatalogRelation(nullptr, "sample", 0)); |
| sample_relation_->addAttribute(new CatalogAttribute( |
| sample_relation_.get(), |
| "int_attr", |
| TypeFactory::GetType(kInt, true))); |
| sample_relation_->addAttribute(new CatalogAttribute( |
| sample_relation_.get(), |
| "double_attr", |
| TypeFactory::GetType(kDouble, true))); |
| sample_relation_->addAttribute(new CatalogAttribute( |
| sample_relation_.get(), |
| "varchar_attr", |
| TypeFactory::GetType(kVarChar, 20, true))); |
| |
| std::unique_ptr<NativeColumnVector> int_column( |
| new NativeColumnVector(TypeFactory::GetType(kInt, true), kNumSampleTuples)); |
| std::unique_ptr<NativeColumnVector> double_column( |
| new NativeColumnVector(TypeFactory::GetType(kDouble, true), kNumSampleTuples)); |
| std::unique_ptr<IndirectColumnVector> varchar_column( |
| new IndirectColumnVector(TypeFactory::GetType(kVarChar, 20, true), kNumSampleTuples)); |
| |
| for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) { |
| if (tuple_num % 10 == 0) { |
| int_column->appendNullValue(); |
| } else { |
| int int_value = tuple_num; |
| int_column->appendUntypedValue(&int_value); |
| } |
| |
| if (tuple_num % 10 == 1) { |
| double_column->appendNullValue(); |
| } else { |
| double double_value = -static_cast<double>(tuple_num); |
| double_column->appendUntypedValue(&double_value); |
| } |
| |
| if (tuple_num % 10 == 2) { |
| TypedValue varchar_value(kVarChar); |
| varchar_column->appendTypedValue(std::move(varchar_value)); |
| } else { |
| char varchar_buffer[20]; |
| int written = snprintf(varchar_buffer, |
| sizeof(varchar_buffer), |
| "aa%05dzz", |
| static_cast<int>(tuple_num)); |
| ASSERT_LT(written, 20); |
| TypedValue varchar_value(kVarChar, varchar_buffer, std::strlen(varchar_buffer) + 1); |
| varchar_value.ensureNotReference(); |
| varchar_column->appendTypedValue(std::move(varchar_value)); |
| } |
| } |
| |
| sample_data_value_accessor_.addColumn(int_column.release()); |
| sample_data_value_accessor_.addColumn(double_column.release()); |
| sample_data_value_accessor_.addColumn(varchar_column.release()); |
| } |
| |
| std::unique_ptr<CatalogRelation> sample_relation_; |
| ColumnVectorsValueAccessor sample_data_value_accessor_; |
| }; |
| |
| constexpr std::size_t ScalarCaseExpressionTest::kNumSampleTuples; |
| |
| // Test a CASE expression that always goes to the same branch. |
| TEST_F(ScalarCaseExpressionTest, StaticBranchTest) { |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| const Type &int_type = TypeFactory::GetType(kInt); |
| |
| // WHEN 1 > 2 THEN int_attr |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type), |
| new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type))); |
| result_expressions.emplace_back( |
| new ScalarAttribute(*sample_relation_->getAttributeById(0))); |
| |
| // WHEN 2 > 1 THEN 42 |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type), |
| new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type))); |
| result_expressions.emplace_back( |
| new ScalarLiteral(TypedValue(static_cast<int>(42)), int_type)); |
| |
| // WHEN int_attr > 50 THEN -int_attr |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(static_cast<int>(50)), int_type))); |
| result_expressions.emplace_back(new ScalarUnaryExpression( |
| UnaryOperationFactory::GetUnaryOperation(UnaryOperationID::kNegate), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| // ELSE 123 |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kInt, true), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarLiteral(TypedValue(static_cast<int>(123)), int_type)); |
| |
| ASSERT_TRUE(case_expr.hasStaticValue()); |
| ASSERT_EQ(kInt, case_expr.getStaticValue().getTypeID()); |
| ASSERT_FALSE(case_expr.getStaticValue().isNull()); |
| EXPECT_EQ(42, case_expr.getStaticValue().getLiteral<int>()); |
| } |
| |
| TEST_F(ScalarCaseExpressionTest, BasicComparisonAndLiteralTest) { |
| static const char kFirstLawString[] |
| = "A robot may not injure a human being or, through inaction, allow a " |
| "human being to come to harm."; |
| static const char kSecondLawString[] |
| = "A robot must obey orders given it by human beings except where such " |
| "orders would conflict with the First Law."; |
| static const char kThirdLawString[] |
| = "A robot must protect its own existence as long as such protection " |
| "does not conflict with the First or Second Law."; |
| |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| const Type &int_type = TypeFactory::GetType(kInt); |
| const Type &double_type = TypeFactory::GetType(kDouble); |
| // Maximum length of 160 chars is longer than is needed for any of the |
| // sample strings. |
| const Type &varchar_type = TypeFactory::GetType(kVarChar, 160, false); |
| |
| // WHEN int_attr < [kNumSampleTuples / 2] THEN [kFirstLawString] |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(static_cast<int>(kNumSampleTuples / 2)), int_type))); |
| result_expressions.emplace_back(new ScalarLiteral( |
| TypedValue(kVarChar, kFirstLawString, std::strlen(kFirstLawString) + 1), |
| varchar_type)); |
| |
| // WHEN double_attr > [-kNumSampleTuples * 3.0 / 4.0] THEN [kSecondLawString] |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1)), |
| new ScalarLiteral(TypedValue(-static_cast<double>(kNumSampleTuples) * 3.0 / 4.0), double_type))); |
| result_expressions.emplace_back(new ScalarLiteral( |
| TypedValue(kVarChar, kSecondLawString, std::strlen(kSecondLawString) + 1), |
| varchar_type)); |
| |
| // ELSE [kThirdLawString] |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kVarChar, 160, false), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarLiteral(TypedValue(kVarChar, kThirdLawString, std::strlen(kThirdLawString) + 1), |
| varchar_type)); |
| |
| ColumnVectorPtr result_cv( |
| case_expr.getAllValues(&sample_data_value_accessor_, nullptr, nullptr)); |
| ASSERT_FALSE(result_cv->isNative()); |
| const IndirectColumnVector &indirect_result_cv |
| = static_cast<const IndirectColumnVector&>(*result_cv); |
| EXPECT_EQ(kNumSampleTuples, indirect_result_cv.size()); |
| |
| for (std::size_t tuple_num = 0; |
| tuple_num < indirect_result_cv.size(); |
| ++tuple_num) { |
| if ((tuple_num % 10 != 0) && (tuple_num < kNumSampleTuples / 2)) { |
| EXPECT_STREQ(kFirstLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } else if ((tuple_num % 10 != 1) && (tuple_num < kNumSampleTuples * 3 / 4)) { |
| EXPECT_STREQ(kSecondLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } else { |
| EXPECT_STREQ(kThirdLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } |
| } |
| } |
| |
| // Same as above, but use a TupleIdSequenceAdapterValueAccessor that filters |
| // the input. |
| TEST_F(ScalarCaseExpressionTest, BasicComparisonAndLiteralWithFilteredInputTest) { |
| // Filter out every seventh tuple. |
| TupleIdSequence filter_sequence(kNumSampleTuples); |
| std::size_t num_filtered_tuples = 0; |
| for (std::size_t tuple_num = 0; |
| tuple_num < kNumSampleTuples; |
| ++tuple_num) { |
| if (tuple_num % 7 != 0) { |
| filter_sequence.set(tuple_num, true); |
| ++num_filtered_tuples; |
| } |
| } |
| |
| std::unique_ptr<ValueAccessor> filtered_accessor( |
| sample_data_value_accessor_.createSharedTupleIdSequenceAdapter(filter_sequence)); |
| |
| static const char kFirstLawString[] |
| = "A robot may not injure a human being or, through inaction, allow a " |
| "human being to come to harm."; |
| static const char kSecondLawString[] |
| = "A robot must obey orders given it by human beings except where such " |
| "orders would conflict with the First Law."; |
| static const char kThirdLawString[] |
| = "A robot must protect its own existence as long as such protection " |
| "does not conflict with the First or Second Law."; |
| |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| const Type &int_type = TypeFactory::GetType(kInt); |
| const Type &double_type = TypeFactory::GetType(kDouble); |
| const Type &varchar_type = TypeFactory::GetType(kVarChar, 160, false); |
| |
| // WHEN int_attr < [kNumSampleTuples / 2] THEN [kFirstLawString] |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(static_cast<int>(kNumSampleTuples / 2)), int_type))); |
| result_expressions.emplace_back(new ScalarLiteral( |
| TypedValue(kVarChar, kFirstLawString, std::strlen(kFirstLawString) + 1), |
| varchar_type)); |
| |
| // WHEN double_attr > [kNumSampleTuples * 3.0 / 4.0] THEN [kSecondLawString] |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1)), |
| new ScalarLiteral(TypedValue(-static_cast<double>(kNumSampleTuples) * 3.0 / 4.0), |
| double_type))); |
| result_expressions.emplace_back(new ScalarLiteral( |
| TypedValue(kVarChar, kSecondLawString, std::strlen(kSecondLawString) + 1), |
| varchar_type)); |
| |
| // ELSE [kThirdLawString] |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kVarChar, 160, false), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarLiteral(TypedValue(kVarChar, kThirdLawString, std::strlen(kThirdLawString) + 1), |
| varchar_type)); |
| |
| ColumnVectorPtr result_cv( |
| case_expr.getAllValues(filtered_accessor.get(), |
| nullptr /* sub_blocks_ref */, |
| nullptr /* cv_cache */)); |
| ASSERT_FALSE(result_cv->isNative()); |
| const IndirectColumnVector &indirect_result_cv |
| = static_cast<const IndirectColumnVector&>(*result_cv); |
| EXPECT_EQ(num_filtered_tuples, indirect_result_cv.size()); |
| |
| std::size_t original_tuple_num = 0; |
| for (std::size_t tuple_num = 0; |
| tuple_num < indirect_result_cv.size(); |
| ++tuple_num, ++original_tuple_num) { |
| if (original_tuple_num % 7 == 0) { |
| ++original_tuple_num; |
| } |
| |
| if ((original_tuple_num % 10 != 0) && (original_tuple_num < kNumSampleTuples / 2)) { |
| EXPECT_STREQ(kFirstLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } else if ((original_tuple_num % 10 != 1) && (original_tuple_num < kNumSampleTuples * 3 / 4)) { |
| EXPECT_STREQ(kSecondLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } else { |
| EXPECT_STREQ(kThirdLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } |
| } |
| } |
| |
| // If a tuple matches multiple WHEN clauses, get the value from the first one. |
| TEST_F(ScalarCaseExpressionTest, WhenClauseOrderTest) { |
| static const char kFirstLawString[] |
| = "A robot may not injure a human being or, through inaction, allow a " |
| "human being to come to harm."; |
| static const char kSecondLawString[] |
| = "A robot must obey orders given it by human beings except where such " |
| "orders would conflict with the First Law."; |
| static const char kThirdLawString[] |
| = "A robot must protect its own existence as long as such protection " |
| "does not conflict with the First or Second Law."; |
| |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| const Type &int_type = TypeFactory::GetType(kInt); |
| const Type &varchar_type = TypeFactory::GetType(kVarChar, 160, false); |
| |
| // WHEN int_attr < 500 THEN [kFirstLawString] |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(static_cast<int>(500)), int_type))); |
| result_expressions.emplace_back(new ScalarLiteral( |
| TypedValue(kVarChar, kFirstLawString, std::strlen(kFirstLawString) + 1), |
| varchar_type)); |
| |
| // WHEN int_attr < 750 THEN [kSecondLawString] |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(static_cast<int>(750)), int_type))); |
| result_expressions.emplace_back(new ScalarLiteral( |
| TypedValue(kVarChar, kSecondLawString, std::strlen(kSecondLawString) + 1), |
| varchar_type)); |
| |
| // ELSE [kThirdLawString] |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kVarChar, 160, false), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarLiteral(TypedValue(kVarChar, kThirdLawString, std::strlen(kThirdLawString) + 1), |
| varchar_type)); |
| |
| ColumnVectorPtr result_cv( |
| case_expr.getAllValues(&sample_data_value_accessor_, |
| nullptr /* sub_blocks_ref */, |
| nullptr /* cv_cache */)); |
| ASSERT_FALSE(result_cv->isNative()); |
| const IndirectColumnVector &indirect_result_cv |
| = static_cast<const IndirectColumnVector&>(*result_cv); |
| EXPECT_EQ(kNumSampleTuples, indirect_result_cv.size()); |
| |
| for (std::size_t tuple_num = 0; |
| tuple_num < indirect_result_cv.size(); |
| ++tuple_num) { |
| if ((tuple_num % 10 != 0) && (tuple_num < 500)) { |
| EXPECT_STREQ(kFirstLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } else if ((tuple_num % 10 != 0) && (tuple_num < 750)) { |
| EXPECT_STREQ(kSecondLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } else { |
| EXPECT_STREQ(kThirdLawString, |
| static_cast<const char*>(indirect_result_cv.getUntypedValue(tuple_num))); |
| } |
| } |
| } |
| |
| TEST_F(ScalarCaseExpressionTest, ComplexConjunctionAndCalculatedExpressionTest) { |
| const Type &int_type = TypeFactory::GetType(kInt); |
| const Type &varchar_type = TypeFactory::GetType(kVarChar, 20, false); |
| |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| // WHEN (int_attr * 2) + double_attr < 500 AND varchar_attr > 'aa002' |
| // THEN -int_attr |
| static const char kStringLit1[] = "aa002"; |
| ConjunctionPredicate *first_case_predicate = new ConjunctionPredicate(); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(2), int_type)), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1))), |
| new ScalarLiteral(TypedValue(static_cast<int>(kNumSampleTuples / 2)), int_type))); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit1, std::strlen(kStringLit1) + 1), varchar_type))); |
| when_predicates.emplace_back(first_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarUnaryExpression( |
| UnaryOperationFactory::GetUnaryOperation(UnaryOperationID::kNegate), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| // WHEN int_attr >= 50 AND (CASE WHEN varchar_attr < 'aa007' THEN 1 ELSE -1) > 0 |
| // THEN int_attr + 10 |
| static const char kStringLit2[] = "aa007"; |
| ConjunctionPredicate *second_case_predicate = new ConjunctionPredicate(); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreaterOrEqual), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(50), int_type))); |
| std::vector<std::unique_ptr<Predicate>> subcase_preds; |
| subcase_preds.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit2, std::strlen(kStringLit2) + 1), varchar_type))); |
| std::vector<std::unique_ptr<Scalar>> subcase_results; |
| subcase_results.emplace_back(new ScalarLiteral(TypedValue(1), int_type)); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarCaseExpression( |
| TypeFactory::GetType(kInt, false), |
| std::move(subcase_preds), |
| std::move(subcase_results), |
| new ScalarLiteral(TypedValue(-1), int_type)), |
| new ScalarLiteral(TypedValue(0), int_type))); |
| when_predicates.emplace_back(second_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(10), int_type))); |
| |
| // ELSE 5 * int_attr |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kInt, true), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarLiteral(TypedValue(5), int_type), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| ColumnVectorPtr result_cv( |
| case_expr.getAllValues(&sample_data_value_accessor_, |
| nullptr /* sub_blocks_ref */, |
| nullptr /* cv_cache */)); |
| ASSERT_TRUE(result_cv->isNative()); |
| const NativeColumnVector &native_result_cv |
| = static_cast<const NativeColumnVector&>(*result_cv); |
| EXPECT_EQ(kNumSampleTuples, native_result_cv.size()); |
| |
| for (std::size_t tuple_num = 0; |
| tuple_num < native_result_cv.size(); |
| ++tuple_num) { |
| if ((tuple_num % 10 != 0) |
| && (tuple_num % 10 != 1) |
| && (tuple_num % 10 != 2) |
| && (tuple_num < 500) |
| && (tuple_num >= 200)) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(-static_cast<int>(tuple_num), |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else if ((tuple_num % 10 != 0) |
| && (tuple_num % 10 != 2) |
| && (tuple_num >= 50 && tuple_num < 700)) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(tuple_num) + 10, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else if (tuple_num % 10 == 0) { |
| EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| } else { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(tuple_num) * 5, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } |
| } |
| } |
| |
| // Same as above, but use a TupleIdSequenceAdapterValueAccessor that filters |
| // the input. |
| TEST_F(ScalarCaseExpressionTest, |
| ComplexConjunctionAndCalculatedExpressionWithFilteredInputTest) { |
| // Filter out every seventh tuple. |
| TupleIdSequence filter_sequence(kNumSampleTuples); |
| std::size_t num_filtered_tuples = 0; |
| for (std::size_t tuple_num = 0; |
| tuple_num < kNumSampleTuples; |
| ++tuple_num) { |
| if (tuple_num % 7 != 0) { |
| filter_sequence.set(tuple_num, true); |
| ++num_filtered_tuples; |
| } |
| } |
| |
| std::unique_ptr<ValueAccessor> filtered_accessor( |
| sample_data_value_accessor_.createSharedTupleIdSequenceAdapter(filter_sequence)); |
| |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| const Type &int_type = TypeFactory::GetType(kInt); |
| const Type &varchar_type = TypeFactory::GetType(kVarChar, 20, false); |
| |
| // WHEN (int_attr * 2) + double_attr < 500 AND varchar_attr > 'aa002' |
| // THEN -int_attr |
| static const char kStringLit1[] = "aa002"; |
| ConjunctionPredicate *first_case_predicate = new ConjunctionPredicate(); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(2), int_type)), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1))), |
| new ScalarLiteral(TypedValue(static_cast<int>(kNumSampleTuples / 2)), int_type))); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit1, std::strlen(kStringLit1) + 1), varchar_type))); |
| when_predicates.emplace_back(first_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarUnaryExpression( |
| UnaryOperationFactory::GetUnaryOperation(UnaryOperationID::kNegate), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| // WHEN int_attr >= 50 AND (CASE WHEN varchar_attr < 'aa007' THEN 1 ELSE -1) > 0 |
| // THEN int_attr + 10 |
| static const char kStringLit2[] = "aa007"; |
| ConjunctionPredicate *second_case_predicate = new ConjunctionPredicate(); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreaterOrEqual), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(50), int_type))); |
| std::vector<std::unique_ptr<Predicate>> subcase_preds; |
| subcase_preds.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit2, std::strlen(kStringLit2) + 1), varchar_type))); |
| std::vector<std::unique_ptr<Scalar>> subcase_results; |
| subcase_results.emplace_back(new ScalarLiteral(TypedValue(1), int_type)); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarCaseExpression( |
| TypeFactory::GetType(kInt, false), |
| std::move(subcase_preds), |
| std::move(subcase_results), |
| new ScalarLiteral(TypedValue(-1), int_type)), |
| new ScalarLiteral(TypedValue(0), int_type))); |
| when_predicates.emplace_back(second_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(10), int_type))); |
| |
| // ELSE 5 * int_attr |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kInt, true), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarLiteral(TypedValue(5), int_type), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| ColumnVectorPtr result_cv( |
| case_expr.getAllValues(filtered_accessor.get(), |
| nullptr /* sub_blocks_ref */, |
| nullptr /* cv_cache */)); |
| ASSERT_TRUE(result_cv->isNative()); |
| const NativeColumnVector &native_result_cv |
| = static_cast<const NativeColumnVector&>(*result_cv); |
| EXPECT_EQ(num_filtered_tuples, native_result_cv.size()); |
| |
| std::size_t original_tuple_num = 0; |
| for (std::size_t tuple_num = 0; |
| tuple_num < native_result_cv.size(); |
| ++tuple_num, ++original_tuple_num) { |
| if (original_tuple_num % 7 == 0) { |
| ++original_tuple_num; |
| } |
| |
| if ((original_tuple_num % 10 != 0) |
| && (original_tuple_num % 10 != 1) |
| && (original_tuple_num % 10 != 2) |
| && (original_tuple_num < 500) |
| && (original_tuple_num >= 200)) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(-static_cast<int>(original_tuple_num), |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else if ((original_tuple_num % 10 != 0) |
| && (original_tuple_num % 10 != 2) |
| && (original_tuple_num >= 50 && original_tuple_num < 700)) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(original_tuple_num) + 10, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else if (original_tuple_num % 10 == 0) { |
| EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| } else { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(original_tuple_num) * 5, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } |
| } |
| } |
| |
| TEST_F(ScalarCaseExpressionTest, ComplexDisjunctionAndCalculatedExpressionTest) { |
| const Type &int_type = TypeFactory::GetType(kInt); |
| const Type &varchar_type = TypeFactory::GetType(kVarChar, 20, false); |
| |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| // WHEN (int_attr * 2) + double_attr > 750 OR varchar_attr < 'aa002' |
| // THEN -int_attr |
| static const char kStringLit1[] = "aa002"; |
| DisjunctionPredicate *first_case_predicate = new DisjunctionPredicate(); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(2), int_type)), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1))), |
| new ScalarLiteral(TypedValue(750), int_type))); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit1, std::strlen(kStringLit1) + 1), varchar_type))); |
| when_predicates.emplace_back(first_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarUnaryExpression( |
| UnaryOperationFactory::GetUnaryOperation(UnaryOperationID::kNegate), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| // WHEN int_attr < 350 OR (CASE WHEN varchar_attr < 'aa006' THEN 1 ELSE -1) < 0 |
| // THEN int_attr + 10 |
| static const char kStringLit2[] = "aa006"; |
| DisjunctionPredicate *second_case_predicate = new DisjunctionPredicate(); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(350), int_type))); |
| std::vector<std::unique_ptr<Predicate>> subcase_preds; |
| subcase_preds.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit2, std::strlen(kStringLit2) + 1), varchar_type))); |
| std::vector<std::unique_ptr<Scalar>> subcase_results; |
| subcase_results.emplace_back(new ScalarLiteral(TypedValue(1), int_type)); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarCaseExpression( |
| TypeFactory::GetType(kInt, false), |
| std::move(subcase_preds), |
| std::move(subcase_results), |
| new ScalarLiteral(TypedValue(-1), int_type)), |
| new ScalarLiteral(TypedValue(0), int_type))); |
| when_predicates.emplace_back(second_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(10), int_type))); |
| |
| // ELSE 5 * int_attr |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kInt, true), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarLiteral(TypedValue(5), int_type), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| ColumnVectorPtr result_cv( |
| case_expr.getAllValues(&sample_data_value_accessor_, |
| nullptr /* sub_blocks_ref */, |
| nullptr /* cv_cache */)); |
| ASSERT_TRUE(result_cv->isNative()); |
| const NativeColumnVector &native_result_cv |
| = static_cast<const NativeColumnVector&>(*result_cv); |
| EXPECT_EQ(kNumSampleTuples, native_result_cv.size()); |
| |
| for (std::size_t tuple_num = 0; |
| tuple_num < native_result_cv.size(); |
| ++tuple_num) { |
| if (tuple_num % 10 == 0) { |
| EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| } else if (((tuple_num % 10 != 1) && (tuple_num > 750)) |
| || ((tuple_num % 10 != 2) && (tuple_num < 200))) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(-static_cast<int>(tuple_num), |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else if ((tuple_num < 350) |
| || (tuple_num % 10 == 2) |
| || ((tuple_num % 10 != 2) && (tuple_num >= 600))) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(tuple_num) + 10, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(tuple_num) * 5, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } |
| } |
| } |
| |
| // Same as above, but use a TupleIdSequenceAdapterValueAccessor that filters |
| // the input. |
| TEST_F(ScalarCaseExpressionTest, |
| ComplexDisjunctionAndCalculatedExpressionWithFilteredInputTest) { |
| const Type &int_type = TypeFactory::GetType(kInt); |
| const Type &varchar_type = TypeFactory::GetType(kVarChar, 20, false); |
| |
| // Filter out every seventh tuple. |
| TupleIdSequence filter_sequence(kNumSampleTuples); |
| std::size_t num_filtered_tuples = 0; |
| for (std::size_t tuple_num = 0; |
| tuple_num < kNumSampleTuples; |
| ++tuple_num) { |
| if (tuple_num % 7 != 0) { |
| filter_sequence.set(tuple_num, true); |
| ++num_filtered_tuples; |
| } |
| } |
| |
| std::unique_ptr<ValueAccessor> filtered_accessor( |
| sample_data_value_accessor_.createSharedTupleIdSequenceAdapter(filter_sequence)); |
| |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| // WHEN (int_attr * 2) + double_attr > 750 OR varchar_attr < 'aa002' |
| // THEN -int_attr |
| static const char kStringLit1[] = "aa002"; |
| DisjunctionPredicate *first_case_predicate = new DisjunctionPredicate(); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(2), int_type)), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1))), |
| new ScalarLiteral(TypedValue(750), int_type))); |
| first_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit1, std::strlen(kStringLit1) + 1), varchar_type))); |
| when_predicates.emplace_back(first_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarUnaryExpression( |
| UnaryOperationFactory::GetUnaryOperation(UnaryOperationID::kNegate), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| // WHEN int_attr < 350 OR (CASE WHEN varchar_attr < 'aa006' THEN 1 ELSE -1) < 0 |
| // THEN int_attr + 10 |
| static const char kStringLit2[] = "aa006"; |
| DisjunctionPredicate *second_case_predicate = new DisjunctionPredicate(); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(350), int_type))); |
| std::vector<std::unique_ptr<Predicate>> subcase_preds; |
| subcase_preds.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(2)), |
| new ScalarLiteral(TypedValue(kVarChar, kStringLit2, std::strlen(kStringLit2) + 1), varchar_type))); |
| std::vector<std::unique_ptr<Scalar>> subcase_results; |
| subcase_results.emplace_back(new ScalarLiteral(TypedValue(1), int_type)); |
| second_case_predicate->addPredicate(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarCaseExpression( |
| TypeFactory::GetType(kInt, false), |
| std::move(subcase_preds), |
| std::move(subcase_results), |
| new ScalarLiteral(TypedValue(-1), int_type)), |
| new ScalarLiteral(TypedValue(0), int_type))); |
| when_predicates.emplace_back(second_case_predicate); |
| |
| result_expressions.emplace_back(new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarLiteral(TypedValue(10), int_type))); |
| |
| // ELSE 5 * int_attr |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kInt, true), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarLiteral(TypedValue(5), int_type), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)))); |
| |
| ColumnVectorPtr result_cv( |
| case_expr.getAllValues(filtered_accessor.get(), |
| nullptr /* sub_blocks_ref */, |
| nullptr /* cv_cache */)); |
| ASSERT_TRUE(result_cv->isNative()); |
| const NativeColumnVector &native_result_cv |
| = static_cast<const NativeColumnVector&>(*result_cv); |
| EXPECT_EQ(num_filtered_tuples, native_result_cv.size()); |
| |
| std::size_t original_tuple_num = 0; |
| for (std::size_t tuple_num = 0; |
| tuple_num < native_result_cv.size(); |
| ++tuple_num, ++original_tuple_num) { |
| if (original_tuple_num % 7 == 0) { |
| ++original_tuple_num; |
| } |
| |
| if (original_tuple_num % 10 == 0) { |
| EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| } else if (((original_tuple_num % 10 != 1) && (original_tuple_num > 750)) |
| || ((original_tuple_num % 10 != 2) && (original_tuple_num < 200))) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(-static_cast<int>(original_tuple_num), |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else if ((original_tuple_num < 350) |
| || (original_tuple_num % 10 == 2) |
| || ((original_tuple_num % 10 != 2) && (original_tuple_num >= 600))) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(original_tuple_num) + 10, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } else { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(tuple_num)); |
| EXPECT_EQ(static_cast<int>(original_tuple_num) * 5, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(tuple_num))); |
| } |
| } |
| } |
| |
| // Test CASE evaluation over joins, with both WHEN predicates and THEN |
| // expressions referencing attributes in both relations. |
| TEST_F(ScalarCaseExpressionTest, JoinTest) { |
| // Simulate a join with another relation. |
| CatalogRelation other_relation(nullptr, "other", 1); |
| other_relation.addAttribute(new CatalogAttribute(&other_relation, |
| "other_double", |
| TypeFactory::GetType(kDouble, false))); |
| other_relation.addAttribute(new CatalogAttribute(&other_relation, |
| "other_int", |
| TypeFactory::GetType(kInt, false))); |
| |
| static const double kOtherDoubleValues[] = {-250.0, -750.0}; |
| std::unique_ptr<NativeColumnVector> other_double_column( |
| new NativeColumnVector(TypeFactory::GetType(kDouble, false), 2)); |
| other_double_column->appendUntypedValue(kOtherDoubleValues); |
| other_double_column->appendUntypedValue(kOtherDoubleValues + 1); |
| |
| static const int kOtherIntValues[] = {10, -1}; |
| std::unique_ptr<NativeColumnVector> other_int_column( |
| new NativeColumnVector(TypeFactory::GetType(kInt, false), 2)); |
| other_int_column->appendUntypedValue(kOtherIntValues); |
| other_int_column->appendUntypedValue(kOtherIntValues + 1); |
| |
| ColumnVectorsValueAccessor other_accessor; |
| other_accessor.addColumn(other_double_column.release()); |
| other_accessor.addColumn(other_int_column.release()); |
| |
| // Setup expression. |
| std::vector<std::unique_ptr<Predicate>> when_predicates; |
| std::vector<std::unique_ptr<Scalar>> result_expressions; |
| |
| // WHEN double_attr > other_double THEN int_attr + other_int |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kGreater), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1)), |
| new ScalarAttribute(*other_relation.getAttributeById(0)))); |
| result_expressions.emplace_back(new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarAttribute(*other_relation.getAttributeById(1)))); |
| |
| // WHEN double_attr < other_double THEN int_attr * other_int |
| when_predicates.emplace_back(new ComparisonPredicate( |
| ComparisonFactory::GetComparison(ComparisonID::kLess), |
| new ScalarAttribute(*sample_relation_->getAttributeById(1)), |
| new ScalarAttribute(*other_relation.getAttributeById(0)))); |
| result_expressions.emplace_back(new ScalarBinaryExpression( |
| BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply), |
| new ScalarAttribute(*sample_relation_->getAttributeById(0)), |
| new ScalarAttribute(*other_relation.getAttributeById(1)))); |
| |
| // ELSE 0 |
| ScalarCaseExpression case_expr( |
| TypeFactory::GetType(kInt, true), |
| std::move(when_predicates), |
| std::move(result_expressions), |
| new ScalarLiteral(TypedValue(0), TypeFactory::GetType(kInt))); |
| |
| // Create a list of joined tuple-id pairs (just the cross-product of tuples). |
| std::vector<std::pair<tuple_id, tuple_id>> joined_tuple_ids; |
| for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) { |
| joined_tuple_ids.emplace_back(tuple_num, 0); |
| joined_tuple_ids.emplace_back(tuple_num, 1); |
| } |
| |
| ColumnVectorPtr result_cv(case_expr.getAllValuesForJoin( |
| 0, |
| &sample_data_value_accessor_, |
| 1, |
| &other_accessor, |
| joined_tuple_ids, |
| nullptr /* cv_cache */)); |
| ASSERT_TRUE(result_cv->isNative()); |
| const NativeColumnVector &native_result_cv |
| = static_cast<const NativeColumnVector&>(*result_cv); |
| EXPECT_EQ(kNumSampleTuples * 2, native_result_cv.size()); |
| |
| for (std::size_t result_num = 0; |
| result_num < native_result_cv.size(); |
| ++result_num) { |
| // For convenience, calculate expected tuple values here. |
| const bool sample_int_null = ((result_num >> 1) % 10 == 0); |
| const int sample_int = result_num >> 1; |
| const bool sample_double_null = ((result_num >> 1) % 10 == 1); |
| const double sample_double = -static_cast<double>(result_num >> 1); |
| const int other_int = kOtherIntValues[result_num & 0x1]; |
| const double other_double = kOtherDoubleValues[result_num & 0x1]; |
| |
| if (sample_double_null) { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num)); |
| EXPECT_EQ(0, *static_cast<const int*>(native_result_cv.getUntypedValue(result_num))); |
| } else if (sample_double > other_double) { |
| if (sample_int_null) { |
| EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(result_num)); |
| } else { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num)); |
| EXPECT_EQ(sample_int + other_int, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(result_num))); |
| } |
| } else if (sample_double < other_double) { |
| if (sample_int_null) { |
| EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(result_num)); |
| } else { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num)); |
| EXPECT_EQ(sample_int * other_int, |
| *static_cast<const int*>(native_result_cv.getUntypedValue(result_num))); |
| } |
| } else { |
| ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num)); |
| EXPECT_EQ(0, *static_cast<const int*>(native_result_cv.getUntypedValue(result_num))); |
| } |
| } |
| } |
| |
| } // namespace quickstep |