// blob: ba44b7925fc793e013bbcdf8fcae5d96c3adcd48 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <Rewriter/RelRewriter.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
namespace local_engine
{
/// Function references reserved for functions that the engine itself injects into a plan.
/// The values are chosen so that they do not conflict with the references of Spark builtin functions.
enum SelfDefinedFunctionReference
{
    /// Reference assigned to the helper scalar function that GetJsonObjectFunctionWriter
    /// inserts as the first argument of a rewritten get_json_object call.
    GET_JSON_OBJECT = 1'000'000,
};
/// Collect all get_json_object functions and group by json strings.
/// Rewrite the get_json_object functions into flattenJSONStringOnRequired + sparktupleElement. This
/// could avoid repeated parsing the same json string and save a lot of time.
class GetJsonObjectFunctionWriter : public RelRewriter
{
public:
GetJsonObjectFunctionWriter(ParserContextPtr parser_context_) : RelRewriter(parser_context_) { }
~GetJsonObjectFunctionWriter() override = default;
void rewrite(substrait::Rel & rel) override
{
prepare(rel);
rewriteImpl(rel);
}
private:
std::unordered_map<String, std::set<String>> json_required_fields;
/// Collect all get_json_object functions and group by json strings
void prepare(const substrait::Rel & rel)
{
if (rel.has_filter())
{
auto & expr = rel.filter().condition();
prepareOnExpression(expr);
}
if (rel.has_project())
{
for (auto & expr : rel.project().expressions())
{
prepareOnExpression(expr);
}
}
if (rel.has_generate())
{
for (auto & expr : rel.generate().child_output())
{
prepareOnExpression(expr);
}
prepareOnExpression(rel.generate().generator());
}
}
void rewriteImpl(substrait::Rel & rel)
{
if (rel.has_filter())
{
auto * filter = rel.mutable_filter();
auto * expression = filter->mutable_condition();
rewriteExpression(*expression);
}
if (rel.has_project())
{
auto * project = rel.mutable_project();
auto * exprssions = project->mutable_expressions();
for (int i = 0; i < project->expressions_size(); ++i)
{
auto * expr = exprssions->Mutable(i);
rewriteExpression(*expr);
}
}
if (rel.has_generate())
{
auto * generate = rel.mutable_generate();
auto * child_outputs = generate->mutable_child_output();
for (int i = 0; i < child_outputs->size(); ++i)
{
auto * expr = child_outputs->Mutable(i);
rewriteExpression(*expr);
}
rewriteExpression(*generate->mutable_generator());
}
}
void prepareOnExpression(const substrait::Expression & expr)
{
switch (expr.rex_type_case())
{
case substrait::Expression::RexTypeCase::kCast: {
prepareOnExpression(expr.cast().input());
break;
}
case substrait::Expression::RexTypeCase::kIfThen: {
const auto & if_then = expr.if_then();
auto condition_nums = if_then.ifs_size();
for (int i = 0; i < condition_nums; ++i)
{
prepareOnExpression(if_then.ifs(i).if_());
prepareOnExpression(if_then.ifs(i).then());
}
prepareOnExpression(if_then.else_());
break;
}
case substrait::Expression::RexTypeCase::kSingularOrList: {
prepareOnExpression(expr.singular_or_list().value());
break;
}
case substrait::Expression::RexTypeCase::kScalarFunction: {
const auto & scalar_function_pb = expr.scalar_function();
auto function_signature_name_opt = parser_context->getFunctionNameInSignature(scalar_function_pb);
if (!function_signature_name_opt)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow scalar function: {}", scalar_function_pb.DebugString());
auto function_signature_name = *function_signature_name_opt;
for (const auto & arg : scalar_function_pb.arguments())
{
if (arg.has_value())
{
prepareOnExpression(arg.value());
}
}
if (function_signature_name == "get_json_object")
{
auto json_key = scalar_function_pb.arguments(0).DebugString();
if (!json_required_fields.count(json_key))
{
json_required_fields[json_key] = std::set<String>();
}
auto & required_fields = json_required_fields.at(json_key);
auto json_path_pb = scalar_function_pb.arguments(1).value();
if (!json_path_pb.has_literal() || !json_path_pb.literal().has_string())
{
break;
}
required_fields.emplace(json_path_pb.literal().string());
}
break;
}
default:
break;
}
}
void rewriteExpression(substrait::Expression & expr)
{
switch (expr.rex_type_case())
{
case substrait::Expression::RexTypeCase::kCast: {
if (expr.cast().has_input())
rewriteExpression(*expr.mutable_cast()->mutable_input());
break;
}
case substrait::Expression::RexTypeCase::kIfThen: {
auto * if_then = expr.mutable_if_then();
auto condition_nums = if_then->ifs_size();
auto * ifs = if_then->mutable_ifs();
for (int i = 0; i < condition_nums; ++i)
{
rewriteExpression(*ifs->Mutable(i)->mutable_if_());
rewriteExpression(*ifs->Mutable(i)->mutable_then());
}
rewriteExpression(*if_then->mutable_else_());
break;
}
case substrait::Expression::RexTypeCase::kSingularOrList: {
rewriteExpression(*expr.mutable_singular_or_list()->mutable_value());
break;
}
case substrait::Expression::RexTypeCase::kScalarFunction: {
auto & scalar_function_pb = *expr.mutable_scalar_function();
if (scalar_function_pb.arguments().empty())
break;
auto json_key = scalar_function_pb.arguments(0).DebugString();
auto function_signature_name_opt = parser_context->getFunctionNameInSignature(scalar_function_pb);
if (!function_signature_name_opt)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow scalar function: {}", scalar_function_pb.DebugString());
auto function_signature_name = *function_signature_name_opt;
for (auto & arg : *scalar_function_pb.mutable_arguments())
{
if (arg.has_value())
{
rewriteExpression(*arg.mutable_value());
}
}
if (function_signature_name == "get_json_object")
{
if (!json_required_fields.count(json_key))
{
/// This is not expected, but it still could work.
LOG_ERROR(&Poco::Logger::get("GetJsonObjectFunctionWriter"), "Cannot find json key {}", json_key);
break;
}
auto & required_fields = json_required_fields.at(json_key);
if (required_fields.empty())
{
break;
}
auto json_path_pb = scalar_function_pb.arguments(1).value();
if (!json_path_pb.has_literal() || !json_path_pb.literal().has_string())
{
break;
}
String required_fields_str;
int i = 0;
for (const auto & field : required_fields)
{
if (i)
{
required_fields_str += "|";
}
required_fields_str += field;
i += 1;
}
substrait::Expression_ScalarFunction decoded_json_function;
decoded_json_function.set_function_reference(SelfDefinedFunctionReference::GET_JSON_OBJECT);
decoded_json_function.mutable_output_type()->CopyFrom(buildReturnType(required_fields));
auto * arg0 = decoded_json_function.add_arguments();
arg0->CopyFrom(scalar_function_pb.arguments(0));
auto * arg1 = decoded_json_function.add_arguments();
arg1->mutable_value()->mutable_literal()->set_string(required_fields_str);
substrait::Expression new_get_json_object_arg0;
new_get_json_object_arg0.mutable_scalar_function()->CopyFrom(decoded_json_function);
*scalar_function_pb.mutable_arguments()->Mutable(0)->mutable_value() = new_get_json_object_arg0;
}
break;
}
default:
break;
}
}
substrait::Type buildReturnType(const std::set<std::string> & fields)
{
substrait::Type_Struct st;
for (const auto & field : fields)
{
st.add_names(field);
substrait::Type_String str_type;
str_type.set_nullability(substrait::Type_Nullability::Type_Nullability_NULLABILITY_NULLABLE);
st.add_types()->mutable_string()->CopyFrom(str_type);
}
substrait::Type res;
res.mutable_struct_()->CopyFrom(st);
return res;
}
};
class ExpressionsRewriter
{
public:
explicit ExpressionsRewriter(ParserContextPtr parser_context_) : parser_context(parser_context_) { }
void rewrite(substrait::Rel & rel)
{
GetJsonObjectFunctionWriter get_json_object_rewriter(parser_context);
get_json_object_rewriter.rewrite(rel);
}
private:
ParserContextPtr parser_context;
};
}