blob: 00a7710ae42499e9001b77d74ca37c8b450deab4 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
**/
#include "expressions/scalar/ScalarCaseExpression.hpp"
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "catalog/CatalogTypedefs.hpp"
#include "expressions/Expressions.pb.h"
#include "expressions/predicate/Predicate.hpp"
#include "expressions/scalar/Scalar.hpp"
#include "storage/TupleIdSequence.hpp"
#include "storage/ValueAccessor.hpp"
#include "storage/ValueAccessorUtil.hpp"
#include "types/Type.hpp"
#include "types/Type.pb.h"
#include "types/TypedValue.hpp"
#include "types/containers/ColumnVector.hpp"
#include "glog/logging.h"
namespace quickstep {
// Builds a CASE expression from parallel lists of WHEN predicates and THEN
// result expressions, plus an ELSE expression.
//
// 'when_predicates' and 'result_expressions' are moved into this object and
// must have the same length (checked in debug builds). The raw
// 'else_result_expression' pointer is stored in 'else_result_expression_' and
// must not be null (checked in debug builds).
//
// After storing the children, the constructor tries to resolve the CASE ahead
// of time:
//   - 'fixed_result_expression_' is set when the branch that will be taken can
//     be decided purely from static predicate results.
//   - 'has_static_value_' / 'static_value_' are additionally set when that
//     fixed branch itself reports a static value.
ScalarCaseExpression::ScalarCaseExpression(
    const Type &result_type,
    std::vector<std::unique_ptr<Predicate>> &&when_predicates,
    std::vector<std::unique_ptr<Scalar>> &&result_expressions,
    Scalar *else_result_expression)
    : Scalar(result_type),
      when_predicates_(std::move(when_predicates)),
      result_expressions_(std::move(result_expressions)),
      else_result_expression_(else_result_expression),
      fixed_result_expression_(nullptr),
      has_static_value_(false) {
  DCHECK_EQ(when_predicates_.size(), result_expressions_.size());
  DCHECK(else_result_expression_ != nullptr);

#ifdef QUICKSTEP_DEBUG
  for (const std::unique_ptr<Scalar> &result_expr : result_expressions_) {
    DCHECK(result_expr->getType().isSubsumedBy(type_));
  }
#endif  // QUICKSTEP_DEBUG

  DCHECK(else_result_expression_->getType().isSubsumedBy(type_));

  // Resolve case branch statically if possible.
  //
  // CASE takes the FIRST branch whose predicate is true, so a branch can only
  // be fixed ahead of time if every predicate before it is statically false.
  // As soon as a predicate without a static result is encountered, nothing at
  // or after it can be resolved here: that predicate might match at run time
  // and shadow any later statically-true branch. Scanning therefore stops at
  // the first non-static predicate.
  bool static_case_branch = true;
  for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
       case_idx < when_predicates_.size();
       ++case_idx) {
    if (!when_predicates_[case_idx]->hasStaticResult()) {
      static_case_branch = false;
      break;
    }
    if (when_predicates_[case_idx]->getStaticResult()) {
      fixed_result_expression_ = result_expressions_[case_idx].get();
      break;
    }
  }

  // Every WHEN predicate is statically false: the ELSE branch is always taken.
  if (static_case_branch && (fixed_result_expression_ == nullptr)) {
    fixed_result_expression_ = else_result_expression_.get();
  }

  // If the fixed branch also has a static value, cache it. Otherwise the CASE
  // always goes to the same branch, but its value is not static.
  if ((fixed_result_expression_ != nullptr)
      && fixed_result_expression_->hasStaticValue()) {
    has_static_value_ = true;
    static_value_ = fixed_result_expression_->getStaticValue();
  }
}
// Serializes this CASE expression: the result type, every WHEN predicate,
// every THEN result expression, and the ELSE expression are copied into
// extensions of a serialization::Scalar tagged as CASE_EXPRESSION.
serialization::Scalar ScalarCaseExpression::getProto() const {
  namespace proto_ext = serialization;

  serialization::Scalar serialized;
  serialized.set_data_source(serialization::Scalar::CASE_EXPRESSION);

  auto *type_slot =
      serialized.MutableExtension(proto_ext::ScalarCaseExpression::result_type);
  type_slot->CopyFrom(type_.getProto());

  // Repeated extensions: one entry per WHEN predicate, in order.
  for (const auto &predicate : when_predicates_) {
    auto *predicate_slot =
        serialized.AddExtension(proto_ext::ScalarCaseExpression::when_predicate);
    predicate_slot->CopyFrom(predicate->getProto());
  }

  // Repeated extensions: one entry per THEN result expression, in order.
  for (const auto &expression : result_expressions_) {
    auto *expression_slot =
        serialized.AddExtension(proto_ext::ScalarCaseExpression::result_expression);
    expression_slot->CopyFrom(expression->getProto());
  }

  auto *else_slot = serialized.MutableExtension(
      proto_ext::ScalarCaseExpression::else_result_expression);
  else_slot->CopyFrom(else_result_expression_->getProto());

  return serialized;
}
// Returns a newly allocated copy of this CASE expression, built from clones
// of every WHEN predicate, every THEN result expression, and the ELSE
// expression. The copy is returned as a raw pointer.
Scalar* ScalarCaseExpression::clone() const {
  const std::size_t num_predicates = when_predicates_.size();
  const std::size_t num_results = result_expressions_.size();

  // Capacity is reserved up front, so the emplace_back calls below never need
  // to grow the vectors.
  std::vector<std::unique_ptr<Predicate>> cloned_predicates;
  cloned_predicates.reserve(num_predicates);
  for (std::size_t idx = 0; idx < num_predicates; ++idx) {
    cloned_predicates.emplace_back(when_predicates_[idx]->clone());
  }

  std::vector<std::unique_ptr<Scalar>> cloned_results;
  cloned_results.reserve(num_results);
  for (std::size_t idx = 0; idx < num_results; ++idx) {
    cloned_results.emplace_back(result_expressions_[idx]->clone());
  }

  return new ScalarCaseExpression(type_,
                                  std::move(cloned_predicates),
                                  std::move(cloned_results),
                                  else_result_expression_->clone());
}
// Evaluates the CASE for one tuple.
//
// Resolution order:
//   1. If a static value was cached at construction, return a reference to it.
//   2. If the branch was fixed at construction, evaluate only that branch.
//   3. Otherwise test the WHEN predicates in order and evaluate the result
//      expression of the first one that matches, falling back to ELSE.
TypedValue ScalarCaseExpression::getValueForSingleTuple(
    const ValueAccessor &accessor,
    const tuple_id tuple) const {
  if (has_static_value_) {
    return static_value_.makeReferenceToThis();
  }

  if (fixed_result_expression_ != nullptr) {
    return fixed_result_expression_->getValueForSingleTuple(accessor, tuple);
  }

  const std::size_t num_cases = when_predicates_.size();
  for (std::size_t idx = 0; idx < num_cases; ++idx) {
    const Predicate &predicate = *when_predicates_[idx];
    if (predicate.matchesForSingleTuple(accessor, tuple)) {
      return result_expressions_[idx]->getValueForSingleTuple(accessor, tuple);
    }
  }

  // No WHEN predicate matched.
  return else_result_expression_->getValueForSingleTuple(accessor, tuple);
}
// Evaluates the CASE for one pair of joined tuples.
//
// Resolution order mirrors the single-tuple version: cached static value
// first, then the branch fixed at construction, then the WHEN predicates in
// order (first match wins), and finally the ELSE expression.
TypedValue ScalarCaseExpression::getValueForJoinedTuples(
    const ValueAccessor &left_accessor,
    const relation_id left_relation_id,
    const tuple_id left_tuple_id,
    const ValueAccessor &right_accessor,
    const relation_id right_relation_id,
    const tuple_id right_tuple_id) const {
  if (has_static_value_) {
    return static_value_.makeReferenceToThis();
  }

  if (fixed_result_expression_ != nullptr) {
    return fixed_result_expression_->getValueForJoinedTuples(
        left_accessor, left_relation_id, left_tuple_id,
        right_accessor, right_relation_id, right_tuple_id);
  }

  const std::size_t num_cases = when_predicates_.size();
  for (std::size_t idx = 0; idx < num_cases; ++idx) {
    const bool matched = when_predicates_[idx]->matchesForJoinedTuples(
        left_accessor, left_relation_id, left_tuple_id,
        right_accessor, right_relation_id, right_tuple_id);
    if (matched) {
      return result_expressions_[idx]->getValueForJoinedTuples(
          left_accessor, left_relation_id, left_tuple_id,
          right_accessor, right_relation_id, right_tuple_id);
    }
  }

  // No WHEN predicate matched.
  return else_result_expression_->getValueForJoinedTuples(
      left_accessor, left_relation_id, left_tuple_id,
      right_accessor, right_relation_id, right_tuple_id);
}
// Evaluates the CASE for all tuples exposed by 'accessor' and returns the
// results as a ColumnVector.
//
// Fast paths: a cached static value is replicated once per tuple, and a branch
// fixed at construction is evaluated directly. Otherwise the tuples are
// partitioned by the first WHEN predicate they match, each partition is
// evaluated with its own result expression through a TupleIdSequence adapter,
// and the per-branch vectors are multiplexed back into one output vector.
ColumnVectorPtr ScalarCaseExpression::getAllValues(
    ValueAccessor *accessor,
    const SubBlocksReference *sub_blocks_ref,
    ColumnVectorCache *cv_cache) const {
  return InvokeOnValueAccessorMaybeTupleIdSequenceAdapter(
      accessor,
      [&](auto *typed_accessor) -> ColumnVectorPtr {  // NOLINT(build/c++11)
        if (has_static_value_) {
          return ColumnVectorPtr(
              ColumnVector::MakeVectorOfValue(type_,
                                              static_value_,
                                              typed_accessor->getNumTuples()));
        }
        if (fixed_result_expression_ != nullptr) {
          return fixed_result_expression_->getAllValues(
              typed_accessor, sub_blocks_ref, cv_cache);
        }

        const TupleIdSequence *existence_sequence =
            typed_accessor->getTupleIdSequence();

        // 'unmatched' starts out covering every tuple visible through the
        // accessor. Each WHEN predicate removes the tuples it claims, so what
        // is left at the end belongs to the ELSE branch.
        std::unique_ptr<TupleIdSequence> unmatched;
        if (existence_sequence == nullptr) {
          unmatched.reset(new TupleIdSequence(typed_accessor->getEndPosition()));
          unmatched->setRange(0, unmatched->length(), true);
        } else {
          unmatched.reset(new TupleIdSequence(existence_sequence->length()));
          unmatched->assignFrom(*existence_sequence);
        }

        // One TupleIdSequence of matches per WHEN branch. Evaluation stops
        // early once no unmatched tuples remain, so this vector may be shorter
        // than 'when_predicates_'.
        std::vector<std::unique_ptr<TupleIdSequence>> branch_matches;
        const std::size_t num_predicates = when_predicates_.size();
        for (std::size_t idx = 0;
             idx < num_predicates && !unmatched->empty();
             ++idx) {
          branch_matches.emplace_back(when_predicates_[idx]->getAllMatches(
              typed_accessor,
              sub_blocks_ref,
              unmatched.get(),
              existence_sequence));
          unmatched->intersectWithComplement(*branch_matches.back());
        }

        // One ColumnVector of values per evaluated WHEN branch.
        std::vector<ColumnVectorPtr> branch_values;
        const std::size_t num_branches = branch_matches.size();
        for (std::size_t idx = 0; idx < num_branches; ++idx) {
          std::unique_ptr<ValueAccessor> branch_accessor(
              typed_accessor->createSharedTupleIdSequenceAdapter(
                  *branch_matches[idx]));
          branch_values.emplace_back(
              result_expressions_[idx]->getAllValues(
                  branch_accessor.get(), sub_blocks_ref, cv_cache));
        }

        // Values for whatever is left over (left null when nothing is left).
        ColumnVectorPtr else_values;
        if (!unmatched->empty()) {
          std::unique_ptr<ValueAccessor> else_accessor(
              typed_accessor->createSharedTupleIdSequenceAdapter(*unmatched));
          else_values = else_result_expression_->getAllValues(
              else_accessor.get(), sub_blocks_ref, cv_cache);
        }

        // Multiplex per-branch results into a single ColumnVector with values
        // in the correct positions.
        return this->multiplexColumnVectors(
            typed_accessor->getNumTuples(),
            existence_sequence,
            branch_matches,
            *unmatched,
            branch_values,
            else_values);
      });
}
// Evaluates the CASE for every pair in 'joined_tuple_ids' and returns the
// results as a ColumnVector with one value per pair, in the same order.
//
// The pairs are partitioned by the first WHEN predicate they satisfy; each
// partition is evaluated with the matching result expression, the leftover
// pairs with the ELSE expression, and the per-branch vectors are multiplexed
// back into position.
ColumnVectorPtr ScalarCaseExpression::getAllValuesForJoin(
    const relation_id left_relation_id,
    ValueAccessor *left_accessor,
    const relation_id right_relation_id,
    ValueAccessor *right_accessor,
    const std::vector<std::pair<tuple_id, tuple_id>> &joined_tuple_ids,
    ColumnVectorCache *cv_cache) const {
  // Slice 'joined_tuple_ids' apart by case.
  //
  // NOTE(chasseur): We use TupleIdSequence to keep track of the positions in
  // 'joined_tuple_ids' that match for a particular case. This is a bit hacky,
  // since TupleIdSequence is intended to represent tuple IDs in one block.
  // All we're really using it for is to multiplex results into the right
  // place.
  //
  // TODO(chasseur): Currently case predicates are evaluated tuple-at-a-time
  // here (just like in a NestedLoopsJoin). If and when we implement vectorized
  // evaluation of nested-loops predicates (or just filtration of joined IDs),
  // we should use that here.
  TupleIdSequence else_positions(joined_tuple_ids.size());
  else_positions.setRange(0, joined_tuple_ids.size(), true);

  // At most one entry per WHEN predicate; reserving up front means the
  // insertions below never reallocate.
  std::vector<std::unique_ptr<TupleIdSequence>> case_positions;
  std::vector<std::vector<std::pair<tuple_id, tuple_id>>> case_matches;
  case_positions.reserve(when_predicates_.size());
  case_matches.reserve(when_predicates_.size());

  for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
       case_idx < when_predicates_.size();
       ++case_idx) {
    // Every pair has already been claimed by an earlier case.
    if (else_positions.empty()) {
      break;
    }

    // Hold the new sequence in a unique_ptr from the moment it is allocated,
    // so it cannot leak if anything below throws.
    std::unique_ptr<TupleIdSequence> current_case_positions(
        new TupleIdSequence(joined_tuple_ids.size()));
    std::vector<std::pair<tuple_id, tuple_id>> current_case_matches;

    const Predicate &case_predicate = *when_predicates_[case_idx];
    for (const tuple_id pos : else_positions) {
      const std::pair<tuple_id, tuple_id> &check_pair = joined_tuple_ids[pos];
      if (case_predicate.matchesForJoinedTuples(*left_accessor,
                                                left_relation_id,
                                                check_pair.first,
                                                *right_accessor,
                                                right_relation_id,
                                                check_pair.second)) {
        current_case_positions->set(pos, true);
        current_case_matches.emplace_back(check_pair);
      }
    }

    // Pairs claimed by this case are no longer candidates for later cases or
    // for the ELSE branch.
    else_positions.intersectWithComplement(*current_case_positions);

    case_positions.emplace_back(std::move(current_case_positions));
    case_matches.emplace_back(std::move(current_case_matches));
  }

  // Generate a ColumnVector of all the values for each case.
  std::vector<ColumnVectorPtr> case_results;
  case_results.reserve(case_matches.size());
  for (std::vector<std::vector<std::pair<tuple_id, tuple_id>>>::size_type case_idx = 0;
       case_idx < case_matches.size();
       ++case_idx) {
    case_results.emplace_back(result_expressions_[case_idx]->getAllValuesForJoin(
        left_relation_id,
        left_accessor,
        right_relation_id,
        right_accessor,
        case_matches[case_idx],
        cv_cache));
  }

  // Evaluate the ELSE expression only if some pair matched no case;
  // 'else_results' stays null otherwise.
  ColumnVectorPtr else_results;
  if (!else_positions.empty()) {
    std::vector<std::pair<tuple_id, tuple_id>> else_matches;
    for (const tuple_id pos : else_positions) {
      else_matches.emplace_back(joined_tuple_ids[pos]);
    }

    else_results = else_result_expression_->getAllValuesForJoin(
        left_relation_id,
        left_accessor,
        right_relation_id,
        right_accessor,
        else_matches,
        cv_cache);
  }

  // Multiplex per-case results into a single ColumnVector with values in the
  // correct positions.
  return multiplexColumnVectors(
      joined_tuple_ids.size(),
      nullptr,
      case_positions,
      else_positions,
      case_results,
      else_results);
}
void ScalarCaseExpression::MultiplexNativeColumnVector(
const TupleIdSequence *source_sequence,
const TupleIdSequence &case_matches,
const NativeColumnVector &case_result,
NativeColumnVector *output) {
if (source_sequence == nullptr) {
if (case_result.typeIsNullable()) {
TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
for (std::size_t input_pos = 0;
input_pos < case_result.size();
++input_pos, ++output_pos_it) {
const void *value = case_result.getUntypedValue<true>(input_pos);
if (value == nullptr) {
output->positionalWriteNullValue(*output_pos_it);
} else {
output->positionalWriteUntypedValue(*output_pos_it, value);
}
}
} else {
TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
for (std::size_t input_pos = 0;
input_pos < case_result.size();
++input_pos, ++output_pos_it) {
output->positionalWriteUntypedValue(*output_pos_it,
case_result.getUntypedValue<false>(input_pos));
}
}
} else {
if (case_result.typeIsNullable()) {
std::size_t input_pos = 0;
TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
for (std::size_t output_pos = 0;
output_pos < output->size();
++output_pos, ++source_sequence_it) {
if (case_matches.get(*source_sequence_it)) {
const void *value = case_result.getUntypedValue<true>(input_pos++);
if (value == nullptr) {
output->positionalWriteNullValue(output_pos);
} else {
output->positionalWriteUntypedValue(output_pos, value);
}
}
}
} else {
std::size_t input_pos = 0;
TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
for (std::size_t output_pos = 0;
output_pos < output->size();
++output_pos, ++source_sequence_it) {
if (case_matches.get(*source_sequence_it)) {
output->positionalWriteUntypedValue(output_pos,
case_result.getUntypedValue<false>(input_pos++));
}
}
}
}
}
void ScalarCaseExpression::MultiplexIndirectColumnVector(
const TupleIdSequence *source_sequence,
const TupleIdSequence &case_matches,
const IndirectColumnVector &case_result,
IndirectColumnVector *output) {
if (source_sequence == nullptr) {
TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
for (std::size_t input_pos = 0;
input_pos < case_result.size();
++input_pos, ++output_pos_it) {
output->positionalWriteTypedValue(*output_pos_it,
case_result.getTypedValue(input_pos));
}
} else {
std::size_t input_pos = 0;
TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
for (std::size_t output_pos = 0;
output_pos < output->size();
++output_pos, ++source_sequence_it) {
if (case_matches.get(*source_sequence_it)) {
output->positionalWriteTypedValue(output_pos,
case_result.getTypedValue(input_pos++));
}
}
}
}
// Merges the per-branch result vectors into a single ColumnVector of
// 'output_size' values.
//
// 'case_matches[i]' identifies which tuples produced 'case_results[i]';
// 'else_matches' / 'else_result' do the same for the ELSE branch
// ('else_result' may be null when no tuple fell through to ELSE). Branches
// whose match sequence is empty are skipped. The output is a
// NativeColumnVector when the result type allows it and an
// IndirectColumnVector otherwise; every input vector is expected to be of the
// same kind (checked in debug builds).
ColumnVectorPtr ScalarCaseExpression::multiplexColumnVectors(
    const std::size_t output_size,
    const TupleIdSequence *source_sequence,
    const std::vector<std::unique_ptr<TupleIdSequence>> &case_matches,
    const TupleIdSequence &else_matches,
    const std::vector<ColumnVectorPtr> &case_results,
    const ColumnVectorPtr &else_result) const {
  DCHECK_EQ(case_matches.size(), case_results.size());

  const std::size_t num_cases = case_matches.size();

  if (NativeColumnVector::UsableForType(type_)) {
    std::unique_ptr<NativeColumnVector> merged(
        new NativeColumnVector(type_, output_size));
    merged->prepareForPositionalWrites();

    for (std::size_t idx = 0; idx < num_cases; ++idx) {
      DCHECK(case_results[idx]->isNative());
      const TupleIdSequence &matches = *case_matches[idx];
      if (matches.empty()) {
        continue;
      }
      MultiplexNativeColumnVector(
          source_sequence,
          matches,
          static_cast<const NativeColumnVector&>(*case_results[idx]),
          merged.get());
    }

    if (else_result != nullptr) {
      DCHECK(else_result->isNative());
      DCHECK(!else_matches.empty());
      MultiplexNativeColumnVector(
          source_sequence,
          else_matches,
          static_cast<const NativeColumnVector&>(*else_result),
          merged.get());
    }

    return ColumnVectorPtr(merged.release());
  }

  // The result type cannot be stored natively: use an IndirectColumnVector.
  std::unique_ptr<IndirectColumnVector> merged(
      new IndirectColumnVector(type_, output_size));
  merged->prepareForPositionalWrites();

  for (std::size_t idx = 0; idx < num_cases; ++idx) {
    DCHECK(!case_results[idx]->isNative());
    const TupleIdSequence &matches = *case_matches[idx];
    if (matches.empty()) {
      continue;
    }
    MultiplexIndirectColumnVector(
        source_sequence,
        matches,
        static_cast<const IndirectColumnVector&>(*case_results[idx]),
        merged.get());
  }

  if (else_result != nullptr) {
    DCHECK(!else_result->isNative());
    DCHECK(!else_matches.empty());
    MultiplexIndirectColumnVector(
        source_sequence,
        else_matches,
        static_cast<const IndirectColumnVector&>(*else_result),
        merged.get());
  }

  return ColumnVectorPtr(merged.release());
}
// Appends this expression's printable fields to the supplied output vectors,
// after the base-class Scalar fields:
//   - inline field "static_value" (only when a static value is cached),
//     printed as "NULL" for a null value;
//   - container children "when_predicates" and "result_expressions";
//   - non-container child "else_result_expression" (when present).
void ScalarCaseExpression::getFieldStringItems(
    std::vector<std::string> *inline_field_names,
    std::vector<std::string> *inline_field_values,
    std::vector<std::string> *non_container_child_field_names,
    std::vector<const Expression*> *non_container_child_fields,
    std::vector<std::string> *container_child_field_names,
    std::vector<std::vector<const Expression*>> *container_child_fields) const {
  Scalar::getFieldStringItems(inline_field_names,
                              inline_field_values,
                              non_container_child_field_names,
                              non_container_child_fields,
                              container_child_field_names,
                              container_child_fields);

  if (has_static_value_) {
    inline_field_names->emplace_back("static_value");
    inline_field_values->emplace_back(
        static_value_.isNull() ? std::string("NULL")
                               : type_.printValueToString(static_value_));
  }

  container_child_field_names->emplace_back("when_predicates");
  container_child_fields->emplace_back();
  {
    std::vector<const Expression*> &children = container_child_fields->back();
    for (const std::unique_ptr<Predicate> &when_pred : when_predicates_) {
      children.emplace_back(when_pred.get());
    }
  }

  container_child_field_names->emplace_back("result_expressions");
  container_child_fields->emplace_back();
  {
    std::vector<const Expression*> &children = container_child_fields->back();
    for (const std::unique_ptr<Scalar> &result_expr : result_expressions_) {
      children.emplace_back(result_expr.get());
    }
  }

  if (else_result_expression_ == nullptr) {
    return;
  }
  non_container_child_field_names->emplace_back("else_result_expression");
  non_container_child_fields->emplace_back(else_result_expression_.get());
}
} // namespace quickstep