blob: 734e1499425e80659be4ff426db120603746e3dc [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>
#include <gmock/gmock.h>
#include <gtest/gtest-matchers.h>
#include <gtest/gtest.h>
#include "arrow/buffer.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/asof_join_node.h"
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/expression_internal.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/compute/exec/util.h"
#include "arrow/compute/registry.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/dataset/dataset.h"
#include "arrow/dataset/discovery.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_ipc.h"
#include "arrow/dataset/partition.h"
#include "arrow/dataset/plan.h"
#include "arrow/dataset/scanner.h"
#include "arrow/datum.h"
#include "arrow/engine/substrait/extension_set.h"
#include "arrow/engine/substrait/extension_types.h"
#include "arrow/engine/substrait/options.h"
#include "arrow/engine/substrait/serde.h"
#include "arrow/engine/substrait/util.h"
#include "arrow/filesystem/filesystem.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/filesystem/mockfs.h"
#include "arrow/filesystem/test_util.h"
#include "arrow/io/type_fwd.h"
#include "arrow/ipc/options.h"
#include "arrow/ipc/writer.h"
#include "arrow/scalar.h"
#include "arrow/table.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
#include "arrow/type.h"
#include "arrow/type_fwd.h"
#include "arrow/util/async_generator_fwd.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/decimal.h"
#include "arrow/util/future.h"
#include "arrow/util/hash_util.h"
#include "arrow/util/io_util.h"
#include "arrow/util/iterator.h"
#include "arrow/util/key_value_metadata.h"
using testing::ElementsAre;
using testing::Eq;
using testing::HasSubstr;
using testing::UnorderedElementsAre;
namespace arrow {
using internal::checked_cast;
using internal::hash_combine;
namespace engine {
void WriteIpcData(const std::string& path,
const std::shared_ptr<fs::FileSystem> file_system,
const std::shared_ptr<Table> input) {
EXPECT_OK_AND_ASSIGN(auto mmap, file_system->OpenOutputStream(path));
ASSERT_OK_AND_ASSIGN(
auto file_writer,
MakeFileWriter(mmap, input->schema(), ipc::IpcWriteOptions::Defaults()));
TableBatchReader reader(input);
std::shared_ptr<RecordBatch> batch;
while (true) {
ASSERT_OK(reader.ReadNext(&batch));
if (batch == nullptr) {
break;
}
ASSERT_OK(file_writer->WriteRecordBatch(*batch));
}
ASSERT_OK(file_writer->Close());
}
// Single shared NullSinkNodeConsumer instance; the consumer-factory lambdas passed to
// DeserializePlans throughout this file return it.
const auto kNullConsumer = std::make_shared<compute::NullSinkNodeConsumer>();
// Schema used to bind expressions and resolve field references in the tests below.
// It mixes primitive, required (non-nullable), nested (list/struct/list-of-struct),
// dictionary and timestamp fields.
//
// NOTE: field order matters. Several tests hard-code positions in their expected
// Substrait JSON (e.g. "struct" is field index 12), so inserting or reordering
// fields requires updating those tests.
const std::shared_ptr<Schema> kBoringSchema = schema({
field("bool", boolean()),
field("i8", int8()),
field("i32", int32()),
field("i32_req", int32(), /*nullable=*/false),
field("u32", uint32()),
field("i64", int64()),
field("f32", float32()),
field("f32_req", float32(), /*nullable=*/false),
field("f64", float64()),
field("date64", date64()),
field("str", utf8()),
field("list_i32", list(int32())),
// index 12: referenced by position in RecursiveFieldRef / FieldRefsInExpressions
field("struct", struct_({
field("i32", int32()),
field("str", utf8()),
field("struct_i32_str",
struct_({field("i32", int32()), field("str", utf8())})),
})),
field("list_struct", list(struct_({
field("i32", int32()),
field("str", utf8()),
field("struct_i32_str", struct_({field("i32", int32()),
field("str", utf8())})),
}))),
field("dict_str", dictionary(int32(), utf8())),
field("dict_i32", dictionary(int32(), int32())),
field("ts_ns", timestamp(TimeUnit::NANO)),
});
std::shared_ptr<DataType> StripFieldNames(std::shared_ptr<DataType> type) {
if (type->id() == Type::STRUCT) {
FieldVector fields(type->num_fields());
for (int i = 0; i < type->num_fields(); ++i) {
fields[i] = type->field(i)->WithName("");
}
return struct_(std::move(fields));
}
if (type->id() == Type::LIST) {
return list(type->field(0)->WithName(""));
}
return type;
}
/// \brief Rewrite every field reference in `expr` to the path found in kBoringSchema
///
/// Literals are returned as-is, field references are replaced by the result of
/// resolving them against kBoringSchema, and calls are copied with each argument
/// rewritten recursively. This lets expressions that spell the same field
/// differently (by name vs. by index) compare equal.
inline compute::Expression UseBoringRefs(const compute::Expression& expr) {
  if (expr.literal() != nullptr) {
    return expr;
  }

  if (const auto* ref = expr.field_ref()) {
    return compute::field_ref(*ref->FindOne(*kBoringSchema));
  }

  // Neither literal nor field reference: treat the expression as a call and work on
  // a copy so that `expr` itself is left untouched.
  auto rewritten_call = *CallNotNull(expr);
  auto& args = rewritten_call.arguments;
  std::transform(args.begin(), args.end(), args.begin(),
                 [](const compute::Expression& arg) { return UseBoringRefs(arg); });
  return compute::Expression{std::move(rewritten_call)};
}
/// \brief Execute the serialized plan in `buf` and compare its output to a table
///
/// \param expected_table the rows the plan is expected to produce
/// \param buf serialized Substrait plan; only read, never modified
/// \param include_columns if non-empty, the output is reduced to these column
///        indices before comparing
/// \param conversion_options options forwarded to DeserializePlans
/// \param sort_options if non-null, the output is sorted with these options before
///        comparing
///
/// The declaration that feeds the first sink of the deserialized plan is run with
/// DeclarationToTable (single-threaded). Both tables are combined into single chunks
/// and compared with AssertTablesEqualIgnoringOrder. Any failure is reported through
/// fatal gtest assertions.
void CheckRoundTripResult(const std::shared_ptr<Table> expected_table,
                          const std::shared_ptr<Buffer>& buf,
                          const std::vector<int>& include_columns = {},
                          const ConversionOptions& conversion_options = {},
                          const compute::SortOptions* sort_options = NULLPTR) {
  std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg = MakeExtensionIdRegistry();
  ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
  ExtensionSet ext_set(ext_id_reg);
  ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
                                            *buf, [] { return kNullConsumer; },
                                            ext_id_reg, &ext_set, conversion_options));

  // Guard the indexing below: a plan without relations, or a sink without inputs,
  // should fail the test rather than read out of bounds.
  ASSERT_FALSE(sink_decls.empty());
  ASSERT_FALSE(sink_decls[0].inputs.empty());
  // get_if instead of get: a non-Declaration input becomes a test failure instead of
  // an uncaught std::bad_variant_access.
  auto* other_declrs = std::get_if<compute::Declaration>(&sink_decls[0].inputs[0]);
  ASSERT_NE(other_declrs, nullptr);

  ASSERT_OK_AND_ASSIGN(
      auto output_table,
      compute::DeclarationToTable(*other_declrs, /*use_threads=*/false));

  if (!include_columns.empty()) {
    ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns));
  }

  if (sort_options) {
    // `sort_options` points to const, so it is passed by reference; std::move would
    // be a no-op here.
    ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(output_table, *sort_options));
    // The indices were computed from `output_table` itself.
    ASSERT_OK_AND_ASSIGN(auto maybe_table,
                         compute::Take(output_table, std::move(sort_indices),
                                       compute::TakeOptions::NoBoundsCheck()));
    output_table = maybe_table.table();
  }

  ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks());
  ASSERT_OK_AND_ASSIGN(auto merged_expected, expected_table->CombineChunks());
  compute::AssertTablesEqualIgnoringOrder(merged_expected, output_table);
}
/// \brief Count the declarations with factory name "project" in a declaration tree
///
/// `input` itself is counted, followed by all of its inputs recursively. Every input
/// is expected to hold a compute::Declaration; std::get throws otherwise.
int CountProjectNodeOptionsInDeclarations(const compute::Declaration& input) {
  int num_projects = (input.factory_name == "project") ? 1 : 0;
  for (const auto& child : input.inputs) {
    num_projects +=
        CountProjectNodeOptionsInDeclarations(std::get<compute::Declaration>(child));
  }
  return num_projects;
}
/// Validate the number of expected ProjectNodes
///
/// Project nodes are sometimes added by emit elements and we may want to
/// verify that we are not adding too many
void ValidateNumProjectNodes(int expected_projections, const std::shared_ptr<Buffer>& buf,
const ConversionOptions& conversion_options) {
std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg = MakeExtensionIdRegistry();
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
*buf, [] { return kNullConsumer; },
ext_id_reg, &ext_set, conversion_options));
auto& other_declrs = std::get<compute::Declaration>(sink_decls[0].inputs[0]);
int num_projections = CountProjectNodeOptionsInDeclarations(other_declrs);
ASSERT_EQ(num_projections, expected_projections);
}
// Checks that each Substrait type given as JSON deserializes to the expected Arrow
// type, and that serializing that Arrow type and deserializing it again yields the
// same type without registering any extension type.
TEST(Substrait, SupportedTypes) {
// Helper: `json` is a Substrait "Type" message in JSON form, `expected_type` is the
// Arrow type it must map to.
auto ExpectEq = [](std::string_view json, std::shared_ptr<DataType> expected_type) {
ARROW_SCOPED_TRACE(json);
ExtensionSet empty;
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON(
"Type", json, /*ignore_unknown_fields=*/false));
ASSERT_OK_AND_ASSIGN(auto type, DeserializeType(*buf, empty));
EXPECT_EQ(*type, *expected_type);
ASSERT_OK_AND_ASSIGN(auto serialized, SerializeType(*type, &empty));
// None of the types below should require an entry in the extension set.
EXPECT_EQ(empty.num_types(), 0);
// FIXME chokes on NULLABILITY_UNSPECIFIED
// EXPECT_THAT(internal::CheckMessagesEquivalent("Type", *buf, *serialized), Ok());
ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeType(*serialized, empty));
EXPECT_EQ(*roundtripped, *expected_type);
};
// Simple types without parameters.
ExpectEq(R"({"bool": {}})", boolean());
ExpectEq(R"({"i8": {}})", int8());
ExpectEq(R"({"i16": {}})", int16());
ExpectEq(R"({"i32": {}})", int32());
ExpectEq(R"({"i64": {}})", int64());
ExpectEq(R"({"fp32": {}})", float32());
ExpectEq(R"({"fp64": {}})", float64());
ExpectEq(R"({"string": {}})", utf8());
ExpectEq(R"({"binary": {}})", binary());
ExpectEq(R"({"timestamp": {}})", timestamp(TimeUnit::MICRO));
ExpectEq(R"({"date": {}})", date32());
ExpectEq(R"({"time": {}})", time64(TimeUnit::MICRO));
ExpectEq(R"({"timestamp_tz": {}})", timestamp(TimeUnit::MICRO, "UTC"));
ExpectEq(R"({"interval_year": {}})", interval_year());
ExpectEq(R"({"interval_day": {}})", interval_day());
ExpectEq(R"({"uuid": {}})", uuid());
// Parameterized types.
ExpectEq(R"({"fixed_char": {"length": 32}})", fixed_char(32));
ExpectEq(R"({"varchar": {"length": 1024}})", varchar(1024));
ExpectEq(R"({"fixed_binary": {"length": 32}})", fixed_size_binary(32));
ExpectEq(R"({"decimal": {"precision": 27, "scale": 5}})", decimal128(27, 5));
// Nested types; the expected struct children have empty names.
ExpectEq(R"({"struct": {
"types": [
{"i64": {}},
{"list": {"type": {"string":{}} }}
]
}})",
struct_({
field("", int64()),
field("", list(utf8())),
}));
ExpectEq(R"({"map": {
"key": {"string":{"nullability": "NULLABILITY_REQUIRED"}},
"value": {"string":{}}
}})",
map(utf8(), field("", utf8()), false));
}
// Arrow types that are encoded through the extension set: each one is encoded to get
// an anchor, referenced from a "user_defined" Substrait type, and round-tripped.
TEST(Substrait, SupportedExtensionTypes) {
  ExtensionSet ext_set;

  const std::vector<std::shared_ptr<DataType>> extension_types = {
      null(), uint8(), uint16(), uint32(), uint64(),
  };

  for (const auto& expected_type : extension_types) {
    // The next anchor handed out is expected to equal the current number of types.
    const auto anchor = ext_set.num_types();
    EXPECT_THAT(ext_set.EncodeType(*expected_type), ResultWith(Eq(anchor)));

    const std::string type_json =
        "{\"user_defined\": { \"type_reference\": " + std::to_string(anchor) +
        ", \"nullability\": \"NULLABILITY_NULLABLE\" } }";
    ASSERT_OK_AND_ASSIGN(auto buf,
                         internal::SubstraitFromJSON("Type", type_json,
                                                     /*ignore_unknown_fields=*/false));

    ASSERT_OK_AND_ASSIGN(auto type, DeserializeType(*buf, ext_set));
    EXPECT_EQ(*type, *expected_type);

    // Serializing must reuse the entry registered by EncodeType above.
    const auto num_types_before = ext_set.num_types();
    ASSERT_OK_AND_ASSIGN(auto serialized, SerializeType(*type, &ext_set));
    EXPECT_EQ(ext_set.num_types(), num_types_before)
        << "was already added to the set above";

    ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeType(*serialized, ext_set));
    EXPECT_EQ(*roundtripped, *expected_type);
  }
}
// Checks conversion between Substrait NamedStruct messages and Arrow schemas:
// a successful round trip, mismatched name counts, and rejection of metadata under
// EXACT_ROUNDTRIP strictness.
TEST(Substrait, NamedStruct) {
ExtensionSet ext_set;
// Six names for four top-level types: the expected schema below shows the names
// being assigned depth-first, so "d" and "e" go to the children of struct "c".
ASSERT_OK_AND_ASSIGN(auto buf,
internal::SubstraitFromJSON("NamedStruct", R"({
"struct": {
"types": [
{"i64": {}},
{"list": {"type": {"string":{}} }},
{"struct": {
"types": [
{"fp32": {"nullability": "NULLABILITY_REQUIRED"}},
{"string": {}}
]
}},
{"list": {"type": {"string":{}} }},
]
},
"names": ["a", "b", "c", "d", "e", "f"]
})",
/*ignore_unknown_fields=*/false));
ASSERT_OK_AND_ASSIGN(auto schema, DeserializeSchema(*buf, ext_set));
Schema expected_schema({
field("a", int64()),
field("b", list(utf8())),
field("c", struct_({
field("d", float32(), /*nullable=*/false),
field("e", utf8()),
})),
field("f", list(utf8())),
});
EXPECT_EQ(*schema, expected_schema);
// Serialize the deserialized schema and make sure it comes back unchanged.
ASSERT_OK_AND_ASSIGN(auto serialized, SerializeSchema(*schema, &ext_set));
ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeSchema(*serialized, ext_set));
EXPECT_EQ(*roundtripped, expected_schema);
// too few names
ASSERT_OK_AND_ASSIGN(buf, internal::SubstraitFromJSON("NamedStruct", R"({
"struct": {"types": [{"i32": {}}, {"i32": {}}, {"i32": {}}]},
"names": []
})",
/*ignore_unknown_fields=*/false));
EXPECT_THAT(DeserializeSchema(*buf, ext_set), Raises(StatusCode::Invalid));
// too many names
ASSERT_OK_AND_ASSIGN(buf, internal::SubstraitFromJSON("NamedStruct", R"({
"struct": {"types": []},
"names": ["a", "b", "c"]
})",
/*ignore_unknown_fields=*/false));
EXPECT_THAT(DeserializeSchema(*buf, ext_set), Raises(StatusCode::Invalid));
ConversionOptions conversion_options;
conversion_options.strictness = ConversionStrictness::EXACT_ROUNDTRIP;
// no schema metadata allowed with EXACT_ROUNDTRIP
EXPECT_THAT(SerializeSchema(Schema({}, key_value_metadata({{"ext", "yes"}})), &ext_set,
conversion_options),
Raises(StatusCode::Invalid));
// ...but the same schema serializes fine with default conversion options
ASSERT_OK(SerializeSchema(Schema({}, key_value_metadata({{"ext", "yes"}})), &ext_set));
// no field metadata allowed with EXACT_ROUNDTRIP
EXPECT_THAT(
SerializeSchema(Schema({field("a", int32(), key_value_metadata({{"ext", "yes"}}))}),
&ext_set, conversion_options),
Raises(StatusCode::Invalid));
}
// A user-defined type reference whose anchor is unknown to the (empty) extension set
// must be rejected with an Invalid status that mentions the missing anchor.
TEST(Substrait, NoEquivalentArrowType) {
  ExtensionSet empty_ext_set;

  ASSERT_OK_AND_ASSIGN(
      auto type_buf,
      internal::SubstraitFromJSON("Type", R"({"user_defined": {"type_reference": 99}})",
                                  /*ignore_unknown_fields=*/false));

  ASSERT_THAT(
      DeserializeType(*type_buf, empty_ext_set),
      Raises(StatusCode::Invalid, HasSubstr("did not have a corresponding anchor")));
}
// Arrow types for which SerializeType is expected to report NotImplemented.
TEST(Substrait, NoEquivalentSubstraitType) {
  const std::vector<std::shared_ptr<DataType>> unsupported_types = {
      date64(),
      timestamp(TimeUnit::SECOND),
      timestamp(TimeUnit::NANO),
      timestamp(TimeUnit::MICRO, "New York"),
      time32(TimeUnit::SECOND),
      time32(TimeUnit::MILLI),
      time64(TimeUnit::NANO),
      decimal256(76, 67),
      sparse_union({field("i8", int8()), field("f32", float32())}),
      dense_union({field("i8", int8()), field("f32", float32())}),
      dictionary(int32(), utf8()),
      fixed_size_list(float16(), 3),
      duration(TimeUnit::MICRO),
      large_utf8(),
      large_binary(),
      large_list(utf8()),
  };

  for (const auto& unsupported : unsupported_types) {
    ARROW_SCOPED_TRACE(unsupported->ToString());
    // Fresh extension set per type so the cases cannot influence each other.
    ExtensionSet ext_set;
    EXPECT_THAT(SerializeType(*unsupported, &ext_set),
                Raises(StatusCode::NotImplemented));
  }
}
// Checks that Substrait literals given as JSON deserialize to the expected Arrow
// scalar, and that the resulting expression round-trips without using any extension
// function. Every case is run twice: as written, and with "nullable": true added.
TEST(Substrait, SupportedLiterals) {
// Helper: `json` is the body of a Substrait Expression.Literal in JSON form.
auto ExpectEq = [](std::string_view json, Datum expected_value) {
ARROW_SCOPED_TRACE(json);
for (bool nullable : {false, true}) {
// Owns the modified JSON text for the nullable pass. `json` is re-pointed at this
// string below, so it must stay alive for the rest of the iteration. This relies
// on the nullable pass being the last one: `json` is not read after it ends.
std::string json_with_nullable;
if (nullable) {
// Splice `, "nullable": true` in front of the final closing brace.
auto final_closing_brace = json.find_last_of('}');
ASSERT_NE(std::string_view::npos, final_closing_brace);
json_with_nullable =
std::string(json.substr(0, final_closing_brace)) + ", \"nullable\": true}";
json = json_with_nullable;
}
ASSERT_OK_AND_ASSIGN(
auto buf, internal::SubstraitFromJSON("Expression",
"{\"literal\":" + std::string(json) + "}",
/*ignore_unknown_fields=*/false));
ExtensionSet ext_set;
ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
ASSERT_TRUE(expr.literal());
ASSERT_THAT(*expr.literal(), DataEq(expected_value));
ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
EXPECT_EQ(ext_set.num_functions(),
0); // shouldn't need extensions for core literals
ASSERT_OK_AND_ASSIGN(auto roundtripped,
DeserializeExpression(*serialized, ext_set));
ASSERT_TRUE(roundtripped.literal());
ASSERT_THAT(*roundtripped.literal(), DataEq(expected_value));
}
};
ExpectEq(R"({"boolean": true})", Datum(true));
ExpectEq(R"({"i8": 34})", Datum(int8_t(34)));
ExpectEq(R"({"i16": 34})", Datum(int16_t(34)));
ExpectEq(R"({"i32": 34})", Datum(int32_t(34)));
ExpectEq(R"({"i64": "34"})", Datum(int64_t(34)));
ExpectEq(R"({"fp32": 3.5})", Datum(3.5F));
ExpectEq(R"({"fp64": 7.125})", Datum(7.125));
ExpectEq(R"({"string": "hello world"})", Datum("hello world"));
// "enp6" is the base64 encoding of "zzz"
ExpectEq(R"({"binary": "enp6"})", BinaryScalar(Buffer::FromString("zzz")));
ExpectEq(R"({"timestamp": "579"})", TimestampScalar(579, TimeUnit::MICRO));
ExpectEq(R"({"date": "5"})", Date32Scalar(5));
ExpectEq(R"({"time": "64"})", Time64Scalar(64, TimeUnit::MICRO));
ExpectEq(R"({"interval_year_to_month": {"years": 34, "months": 3}})",
ExtensionScalar(FixedSizeListScalar(ArrayFromJSON(int32(), "[34, 3]")),
interval_year()));
ExpectEq(R"({"interval_day_to_second": {"days": 34, "seconds": 3}})",
ExtensionScalar(FixedSizeListScalar(ArrayFromJSON(int32(), "[34, 3]")),
interval_day()));
ExpectEq(R"({"fixed_char": "zzz"})",
ExtensionScalar(
FixedSizeBinaryScalar(Buffer::FromString("zzz"), fixed_size_binary(3)),
fixed_char(3)));
ExpectEq(R"({"var_char": {"value": "zzz", "length": 1024}})",
ExtensionScalar(StringScalar("zzz"), varchar(1024)));
ExpectEq(R"({"fixed_binary": "enp6"})",
FixedSizeBinaryScalar(Buffer::FromString("zzz"), fixed_size_binary(3)));
ExpectEq(
R"({"decimal": {"value": "0gKWSQAAAAAAAAAAAAAAAA==", "precision": 27, "scale": 5}})",
Decimal128Scalar(Decimal128("123456789.0"), decimal128(27, 5)));
ExpectEq(R"({"timestamp_tz": "579"})", TimestampScalar(579, TimeUnit::MICRO, "UTC"));
// special case for empty lists
ExpectEq(R"({"empty_list": {"type": {"i32": {}}}})",
ScalarFromJSON(list(int32()), "[]"));
ExpectEq(R"({"struct": {
"fields": [
{"i64": "32"},
{"list": {"values": [
{"string": "hello"},
{"string": "world"}
]}}
]
}})",
ScalarFromJSON(struct_({
field("", int64()),
field("", list(utf8())),
}),
R"([32, ["hello", "world"]])"));
// check null scalars:
// the type is serialized to Substrait, rendered as JSON and wrapped in {"null": ...}
for (auto type : {
boolean(),
int8(),
int64(),
timestamp(TimeUnit::MICRO),
interval_year(),
struct_({
field("", int64()),
field("", list(utf8())),
}),
}) {
ExtensionSet set;
ASSERT_OK_AND_ASSIGN(auto buf, SerializeType(*type, &set));
ASSERT_OK_AND_ASSIGN(auto json, internal::SubstraitToJSON("Type", *buf));
ExpectEq("{\"null\": " + json + "}", MakeNullScalar(type));
}
}
// Literals that DeserializeExpression must reject with StatusCode::Invalid.
TEST(Substrait, CannotDeserializeLiteral) {
ExtensionSet ext_set;
// Invalid: missing List.element_type
ASSERT_OK_AND_ASSIGN(
auto buf, internal::SubstraitFromJSON("Expression",
R"({"literal": {"list": {"values": []}}})",
/*ignore_unknown_fields=*/false));
EXPECT_THAT(DeserializeExpression(*buf, ext_set), Raises(StatusCode::Invalid));
// Invalid: required null literal if in strict mode
// (a null literal whose type is marked NULLABILITY_REQUIRED, under EXACT_ROUNDTRIP)
ConversionOptions conversion_options;
conversion_options.strictness = ConversionStrictness::EXACT_ROUNDTRIP;
ASSERT_OK_AND_ASSIGN(
buf,
internal::SubstraitFromJSON(
"Expression",
R"({"literal": {"null": {"bool": {"nullability": "NULLABILITY_REQUIRED"}}}})",
/*ignore_unknown_fields=*/false));
EXPECT_THAT(DeserializeExpression(*buf, ext_set, conversion_options),
Raises(StatusCode::Invalid));
// no equivalent arrow scalar
// FIXME no way to specify scalars of user_defined_type_reference
}
// Field references (by name, by index and nested) bound against kBoringSchema must
// resolve to the same field path after a serialize/deserialize round trip.
TEST(Substrait, FieldRefRoundTrip) {
  const std::vector<FieldRef> refs = {
      // by name
      FieldRef("i32"),
      FieldRef("ts_ns"),
      FieldRef("struct"),
      // by index
      FieldRef(0),
      FieldRef(1),
      FieldRef(kBoringSchema->num_fields() - 1),
      FieldRef(kBoringSchema->GetFieldIndex("struct")),
      // nested
      FieldRef("struct", "i32"),
      FieldRef("struct", "struct_i32_str", "i32"),
      FieldRef(kBoringSchema->GetFieldIndex("struct"), 1),
  };

  for (const FieldRef& ref : refs) {
    ARROW_SCOPED_TRACE(ref.ToString());

    ASSERT_OK_AND_ASSIGN(auto bound_expr,
                         compute::field_ref(ref).Bind(*kBoringSchema));

    ExtensionSet ext_set;
    ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(bound_expr, &ext_set));
    // shouldn't need extensions for core field references
    EXPECT_EQ(ext_set.num_functions(), 0);

    ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set));
    ASSERT_TRUE(roundtripped.field_ref());

    // Compare resolved index paths, since the reference may be spelled differently
    // after the round trip.
    ASSERT_OK_AND_ASSIGN(auto expected_path, ref.FindOne(*kBoringSchema));
    ASSERT_OK_AND_ASSIGN(auto actual_path,
                         roundtripped.field_ref()->FindOne(*kBoringSchema));
    EXPECT_EQ(actual_path.indices(), expected_path.indices());
  }
}
// A nested reference ("struct" -> "str") must serialize to nested structField
// selections rooted at the input record.
TEST(Substrait, RecursiveFieldRef) {
FieldRef ref("struct", "str");
ARROW_SCOPED_TRACE(ref.ToString());
ASSERT_OK_AND_ASSIGN(auto expr, compute::field_ref(ref).Bind(*kBoringSchema));
ExtensionSet ext_set;
ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
// Expected message: field 12 is "struct" in kBoringSchema, child field 1 is its
// "str" member.
ASSERT_OK_AND_ASSIGN(auto expected,
internal::SubstraitFromJSON("Expression", R"({
"selection": {
"directReference": {
"structField": {
"field": 12,
"child": {
"structField": {
"field": 1
}
}
}
},
"rootReference": {}
}
})",
/*ignore_unknown_fields=*/false));
ASSERT_OK(internal::CheckMessagesEquivalent("Expression", *serialized, *expected));
}
// A struct_field call applied to the result of another expression (here an if_else)
// must serialize to a selection whose source is that expression instead of the root.
TEST(Substrait, FieldRefsInExpressions) {
// struct_field(if_else(true, struct, struct), index 0)
ASSERT_OK_AND_ASSIGN(auto expr,
compute::call("struct_field",
{compute::call("if_else",
{
compute::literal(true),
compute::field_ref("struct"),
compute::field_ref("struct"),
})},
compute::StructFieldOptions({0}))
.Bind(*kBoringSchema));
ExtensionSet ext_set;
ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
// Expected message: field 12 is "struct" in kBoringSchema; the outer selection picks
// member 0 of the if_then result.
ASSERT_OK_AND_ASSIGN(auto expected,
internal::SubstraitFromJSON("Expression", R"({
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"expression": {
"if_then": {
"ifs": [
{
"if": {"literal": {"boolean": true}},
"then": {"selection": {"directReference": {"structField": {"field": 12}}}}
}
],
"else": {"selection": {"directReference": {"structField": {"field": 12}}}}
}
}
}
})",
/*ignore_unknown_fields=*/false));
ASSERT_OK(internal::CheckMessagesEquivalent("Expression", *serialized, *expected));
}
// Calls that map onto core Substrait expressions (if/then, case/when, list element,
// struct field) must round-trip without registering any extension function.
TEST(Substrait, CallSpecialCaseRoundTrip) {
  const std::vector<compute::Expression> unbound_exprs = {
      compute::call("if_else",
                    {
                        compute::literal(true),
                        compute::field_ref({"struct", 1}),
                        compute::field_ref("str"),
                    }),
      compute::call(
          "case_when",
          {
              compute::call("make_struct",
                            {compute::literal(false), compute::literal(true)},
                            compute::MakeStructOptions({"cond1", "cond2"})),
              compute::field_ref({"struct", "str"}),
              compute::field_ref({"struct", "struct_i32_str", "str"}),
              compute::field_ref("str"),
          }),
      compute::call("list_element",
                    {
                        compute::field_ref("list_i32"),
                        compute::literal(3),
                    }),
      compute::call("struct_field",
                    {compute::call("list_element",
                                   {
                                       compute::field_ref("list_struct"),
                                       compute::literal(42),
                                   })},
                    arrow::compute::StructFieldOptions({1})),
      compute::call("struct_field",
                    {compute::call("list_element",
                                   {
                                       compute::field_ref("list_struct"),
                                       compute::literal(42),
                                   })},
                    arrow::compute::StructFieldOptions({2, 0})),
      compute::call("struct_field",
                    {compute::call("if_else",
                                   {
                                       compute::literal(true),
                                       compute::field_ref("struct"),
                                       compute::field_ref("struct"),
                                   })},
                    compute::StructFieldOptions({0})),
  };

  for (const compute::Expression& unbound : unbound_exprs) {
    ARROW_SCOPED_TRACE(unbound.ToString());
    ASSERT_OK_AND_ASSIGN(auto expr, unbound.Bind(*kBoringSchema));

    ExtensionSet ext_set;
    ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
    // These are special cased as core expressions in substrait; shouldn't require any
    // extensions.
    EXPECT_EQ(ext_set.num_functions(), 0);

    ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set));
    ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*kBoringSchema));
    // Normalize field references on both sides before comparing.
    EXPECT_EQ(UseBoringRefs(roundtripped), UseBoringRefs(expr));
  }
}
// Calls that are not core Substrait expressions must round-trip through exactly one
// extension function entry.
TEST(Substrait, CallExtensionFunction) {
  const std::vector<compute::Expression> unbound_exprs = {
      compute::call("add", {compute::literal(0), compute::literal(1)}),
  };

  for (const compute::Expression& unbound : unbound_exprs) {
    ARROW_SCOPED_TRACE(unbound.ToString());
    ASSERT_OK_AND_ASSIGN(auto expr, unbound.Bind(*kBoringSchema));

    ExtensionSet ext_set;
    ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
    // These require an extension, so we should have a single-element ext_set.
    EXPECT_EQ(ext_set.num_functions(), 1);

    ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set));
    ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*kBoringSchema));
    EXPECT_EQ(UseBoringRefs(roundtripped), UseBoringRefs(expr));
  }
}
// A ReadRel over two local parquet files with a filter must deserialize to a "scan"
// declaration carrying the filter, the files, the format and the base schema.
TEST(Substrait, ReadRel) {
ASSERT_OK_AND_ASSIGN(auto buf,
internal::SubstraitFromJSON("Rel", R"({
"read": {
"base_schema": {
"struct": {
"types": [ {"i64": {}}, {"bool": {}} ]
},
"names": ["i", "b"]
},
"filter": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
}
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat1.parquet",
"parquet": {}
},
{
"uri_file": "file:///tmp/dat2.parquet",
"parquet": {}
}
]
}
}
})",
/*ignore_unknown_fields=*/false));
ExtensionSet ext_set;
ASSERT_OK_AND_ASSIGN(auto rel, DeserializeRelation(*buf, ext_set));
// converting a ReadRel produces a scan Declaration
ASSERT_EQ(rel.factory_name, "scan");
// safe to downcast: the factory name was asserted just above
const auto& scan_node_options =
checked_cast<const dataset::ScanNodeOptions&>(*rel.options);
// filter on the boolean field (#1)
EXPECT_EQ(scan_node_options.scan_options->filter, compute::field_ref(1));
// dataset is a FileSystemDataset in parquet format with the specified schema
ASSERT_EQ(scan_node_options.dataset->type_name(), "filesystem");
const auto& dataset =
checked_cast<const dataset::FileSystemDataset&>(*scan_node_options.dataset);
// the "file://" scheme is gone from the reported paths; order is not checked
EXPECT_THAT(dataset.files(),
UnorderedElementsAre("/tmp/dat1.parquet", "/tmp/dat2.parquet"));
EXPECT_EQ(dataset.format()->type_name(), "parquet");
EXPECT_EQ(*dataset.schema(), Schema({field("i", int64()), field("b", boolean())}));
}
/// \brief Create a NamedTableProvider that provides `table` regardless of the name
///
/// The returned provider ignores the requested names and always yields a
/// "table_source" declaration labelled "mock_source" that wraps `table`.
NamedTableProvider AlwaysProvideSameTable(std::shared_ptr<Table> table) {
  auto provider = [source_table = std::move(table)](
                      const std::vector<std::string>& /*names*/) {
    // Held as a pointer to the ExecNodeOptions base class, exactly as the
    // Declaration constructor is handed it.
    std::shared_ptr<compute::ExecNodeOptions> source_options =
        std::make_shared<compute::TableSourceNodeOptions>(source_table);
    return compute::Declaration("table_source", {}, source_options, "mock_source");
  };
  return provider;
}
// A relation carrying a hint deserializes with the default strictness, but is
// rejected with NotImplemented under EXACT_ROUNDTRIP.
TEST(Substrait, RelWithHint) {
ASSERT_OK_AND_ASSIGN(auto buf,
internal::SubstraitFromJSON("Rel", R"({
"read": {
"common": {
"hint": {
"stats": {
"row_count": 1
}
},
"direct": { }
},
"base_schema": {
"struct": {
"types": [ {"i64": {}}, {"bool": {}} ]
},
"names": ["i", "b"]
},
"named_table": { "names": [ "foo" ] }
}
})",
/*ignore_unknown_fields=*/false));
ConversionOptions conversion_options;
// The plan is only deserialized, never executed, so a null table is enough here.
conversion_options.named_table_provider = AlwaysProvideSameTable(nullptr);
ExtensionSet ext_set;
// default strictness: deserialization succeeds
ASSERT_OK_AND_ASSIGN(auto rel, DeserializeRelation(*buf, ext_set, conversion_options));
// exact round trip: the same relation is rejected
conversion_options.strictness = ConversionStrictness::EXACT_ROUNDTRIP;
ASSERT_RAISES(NotImplemented, DeserializeRelation(*buf, ext_set, conversion_options));
}
// Deserializing a plan must populate the extension set with the type and function
// declared in the plan's "extensions" section, keyed by their anchors. The check is
// run once without an extension id registry (null) and once with a fresh one.
TEST(Substrait, ExtensionSetFromPlan) {
  std::string substrait_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [
{"rel": {
"read": {
"base_schema": {
"struct": {
"types": [ {"i64": {}}, {"bool": {}} ]
},
"names": ["i", "b"]
},
"local_files": { "items": [] }
}
}}
],
"extension_uris": [
{
"extension_uri_anchor": 7,
"uri": ")" + default_extension_types_uri() +
R"("
},
{
"extension_uri_anchor": 18,
"uri": ")" + kSubstraitArithmeticFunctionsUri +
R"("
}
],
"extensions": [
{"extension_type": {
"extension_uri_reference": 7,
"type_anchor": 42,
"name": "null"
}},
{"extension_function": {
"extension_uri_reference": 18,
"function_anchor": 42,
"name": "add"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto buf,
                       internal::SubstraitFromJSON("Plan", substrait_json,
                                                   /*ignore_unknown_fields=*/false));
  for (auto sp_ext_id_reg :
       {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
    ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
    ExtensionSet ext_set(ext_id_reg);
    ASSERT_OK_AND_ASSIGN(auto sink_decls,
                         DeserializePlans(
                             *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set));

    // Fatal assertions: the decoded records are inspected right below, so there is
    // nothing meaningful to check if decoding fails.
    ASSERT_OK_AND_ASSIGN(auto decoded_null_type, ext_set.DecodeType(42));
    EXPECT_EQ(decoded_null_type.id.uri, kArrowExtTypesUri);
    EXPECT_EQ(decoded_null_type.id.name, "null");
    EXPECT_EQ(*decoded_null_type.type, NullType());

    ASSERT_OK_AND_ASSIGN(Id decoded_add_func_id, ext_set.DecodeFunction(42));
    EXPECT_EQ(decoded_add_func_id.uri, kSubstraitArithmeticFunctionsUri);
    EXPECT_EQ(decoded_add_func_id.name, "add");
  }
}
// A plan that declares an extension function named "does_not_exist" must be
// rejected with Invalid when deserialized under EXACT_ROUNDTRIP strictness.
TEST(Substrait, ExtensionSetFromPlanMissingFunc) {
std::string substrait_json = R"({
"relations": [],
"extension_uris": [
{
"extension_uri_anchor": 7,
"uri": ")" + default_extension_types_uri() +
R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 7,
"function_anchor": 42,
"name": "does_not_exist"
}}
]
})";
ASSERT_OK_AND_ASSIGN(auto buf,
internal::SubstraitFromJSON("Plan", substrait_json,
/*ignore_unknown_fields=*/false));
// Run once without an extension id registry (null) and once with a fresh one.
for (auto sp_ext_id_reg :
{std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
// Since the function is not referenced this plan is ok unless we are asking for
// strict conversion.
// (Only the strict case is exercised here.)
ConversionOptions options;
options.strictness = ConversionStrictness::EXACT_ROUNDTRIP;
ASSERT_RAISES(Invalid,
DeserializePlans(
*buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set, options));
}
}
// DeserializePlans must report Invalid when the factory it is given returns null,
// for both the sink-consumer factory and the write-options factory overloads.
TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) {
std::string substrait_json = R"({
"relations": [
{"rel": {
"read": {
"base_schema": {
"struct": {
"types": [ {"i64": {}}, {"bool": {}} ]
},
"names": ["i", "b"]
},
"local_files": { "items": [] }
}
}}
],
"extension_uris": [
{
"extension_uri_anchor": 7,
"uri": ")" + default_extension_types_uri() +
R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 7,
"function_anchor": 42,
"name": "add"
}}
]
})";
ASSERT_OK_AND_ASSIGN(auto buf,
internal::SubstraitFromJSON("Plan", substrait_json,
/*ignore_unknown_fields=*/false));
// Run once without an extension id registry (null) and once with a fresh one.
for (auto sp_ext_id_reg :
{std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
// The explicit lambda return types select the factory overload under test.
// consumer factory that yields no consumer
ASSERT_RAISES(
Invalid,
DeserializePlans(
*buf, []() -> std::shared_ptr<compute::SinkNodeConsumer> { return nullptr; },
ext_id_reg, &ext_set));
// write-options factory that yields no options
ASSERT_RAISES(
Invalid,
DeserializePlans(
*buf, []() -> std::shared_ptr<dataset::WriteNodeOptions> { return nullptr; },
ext_id_reg, &ext_set));
}
}
// A plan that declares an unknown extension function ("new_func") is rejected under
// EXACT_ROUNDTRIP until a mapping for that function is added to the extension id
// registry; afterwards the plan deserializes and the function can be decoded.
TEST(Substrait, ExtensionSetFromPlanRegisterFunc) {
  std::string substrait_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [],
"extension_uris": [
{
"extension_uri_anchor": 7,
"uri": ")" + default_extension_types_uri() +
R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 7,
"function_anchor": 42,
"name": "new_func"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto buf,
                       internal::SubstraitFromJSON("Plan", substrait_json,
                                                   /*ignore_unknown_fields=*/false));
  auto sp_ext_id_reg = MakeExtensionIdRegistry();
  ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();

  // invalid before registration
  ExtensionSet ext_set_invalid(ext_id_reg);
  ConversionOptions conversion_options;
  conversion_options.strictness = ConversionStrictness::EXACT_ROUNDTRIP;
  ASSERT_RAISES(Invalid, DeserializePlans(
                             *buf, [] { return kNullConsumer; }, ext_id_reg,
                             &ext_set_invalid, conversion_options));

  // Map the Substrait function "new_func" to the Arrow function "multiply".
  ASSERT_OK(ext_id_reg->AddSubstraitCallToArrow(
      {default_extension_types_uri(), "new_func"}, "multiply"));

  // valid after registration
  ExtensionSet ext_set_valid(ext_id_reg);
  ASSERT_OK_AND_ASSIGN(auto sink_decls,
                       DeserializePlans(
                           *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set_valid,
                           conversion_options));

  // Fatal assertion: the decoded id is inspected right below, so there is nothing
  // meaningful to check if decoding fails.
  ASSERT_OK_AND_ASSIGN(Id decoded_func_id, ext_set_valid.DecodeFunction(42));
  EXPECT_EQ(decoded_func_id.uri, kArrowExtTypesUri);
  EXPECT_EQ(decoded_func_id.name, "new_func");
}
/// \brief Build a Substrait plan (as JSON) that reads "binary.parquet"
///
/// The file is looked up in the directory named by the PARQUET_TEST_DATA environment
/// variable. The plan consists of a single read relation with one binary column
/// named "foo".
///
/// \return the plan JSON, or an error if the environment variable cannot be read or
///         the file path cannot be constructed
Result<std::string> GetSubstraitJSON() {
  ARROW_ASSIGN_OR_RAISE(std::string dir_string,
                        arrow::internal::GetEnvVar("PARQUET_TEST_DATA"));
  // Propagate path errors through the returned Result instead of dereferencing the
  // intermediate Results unchecked.
  ARROW_ASSIGN_OR_RAISE(auto dir_name,
                        arrow::internal::PlatformFilename::FromString(dir_string));
  ARROW_ASSIGN_OR_RAISE(auto file_name, dir_name.Join("binary.parquet"));
  const std::string file_path = file_name.ToString();

  std::string substrait_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [
{"rel": {
"read": {
"base_schema": {
"struct": {
"types": [
{"binary": {}}
]
},
"names": [
"foo"
]
},
"local_files": {
"items": [
{
"uri_file": "file://FILENAME_PLACEHOLDER",
"parquet": {}
}
]
}
}
}}
]
})";

  // The placeholder is part of the literal above, so find() always succeeds.
  const std::string filename_placeholder = "FILENAME_PLACEHOLDER";
  substrait_json.replace(substrait_json.find(filename_placeholder),
                         filename_placeholder.size(), file_path);
  return substrait_json;
}
// DeserializePlans with a consumer factory must yield one "consuming_sink"
// declaration that can be added to an ExecPlan and run to completion.
TEST(Substrait, DeserializeWithConsumerFactory) {
  ASSERT_OK_AND_ASSIGN(std::string plan_json, GetSubstraitJSON());
  ASSERT_OK_AND_ASSIGN(auto plan_buf, SerializeJsonPlan(plan_json));
  ASSERT_OK_AND_ASSIGN(
      auto sink_decls,
      DeserializePlans(*plan_buf, compute::NullSinkNodeConsumer::Make));
  ASSERT_EQ(sink_decls.size(), 1);

  compute::Declaration& sink_decl = sink_decls[0];
  ASSERT_EQ(sink_decl.factory_name, "consuming_sink");

  ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make());
  ASSERT_OK_AND_ASSIGN(auto sink_node, sink_decl.AddToPlan(plan.get()));
  ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode");
  ASSERT_EQ(sink_node->num_inputs(), 1);

  // The sink is fed directly by the source node.
  auto& source_node = sink_node->inputs()[0];
  ASSERT_STREQ(source_node->kind_name(), "SourceNode");

  ASSERT_OK(plan->StartProducing());
  ASSERT_FINISHES_OK(plan->finished());
}
// DeserializePlan with a ready-made consumer must yield an ExecPlan with a single
// consuming sink fed by a source node, and the plan must run to completion.
TEST(Substrait, DeserializeSinglePlanWithConsumerFactory) {
  ASSERT_OK_AND_ASSIGN(std::string plan_json, GetSubstraitJSON());
  ASSERT_OK_AND_ASSIGN(auto plan_buf, SerializeJsonPlan(plan_json));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<compute::ExecPlan> plan,
      DeserializePlan(*plan_buf, compute::NullSinkNodeConsumer::Make()));
  ASSERT_EQ(1, plan->sinks().size());

  compute::ExecNode* sink_node = plan->sinks()[0];
  ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode");
  ASSERT_EQ(sink_node->num_inputs(), 1);

  auto& source_node = sink_node->inputs()[0];
  ASSERT_STREQ(source_node->kind_name(), "SourceNode");

  ASSERT_OK(plan->StartProducing());
  ASSERT_FINISHES_OK(plan->finished());
}
// DeserializePlans with a write-options factory must yield one "write" declaration
// on top of a "scan" declaration; the plan is then run against a mock filesystem.
TEST(Substrait, DeserializeWithWriteOptionsFactory) {
  dataset::internal::Initialize();

  fs::TimePoint mock_now = std::chrono::system_clock::now();
  fs::FileInfo testdir = ::arrow::fs::Dir("testdir");
  ASSERT_OK_AND_ASSIGN(std::shared_ptr<fs::FileSystem> mock_fs,
                       fs::internal::MockFileSystem::Make(mock_now, {testdir}));

  // Produces IPC write options that target "testdir" on the mock filesystem.
  auto write_options_factory = [&mock_fs] {
    auto ipc_format = std::make_shared<dataset::IpcFileFormat>();
    dataset::FileSystemDatasetWriteOptions write_options;
    write_options.file_write_options = ipc_format->DefaultWriteOptions();
    write_options.filesystem = mock_fs;
    write_options.basename_template = "chunk-{i}.arrow";
    write_options.base_dir = "testdir";
    write_options.partitioning =
        std::make_shared<dataset::DirectoryPartitioning>(arrow::schema({}));
    return std::make_shared<dataset::WriteNodeOptions>(write_options);
  };

  ASSERT_OK_AND_ASSIGN(std::string plan_json, GetSubstraitJSON());
  ASSERT_OK_AND_ASSIGN(auto plan_buf, SerializeJsonPlan(plan_json));
  ASSERT_OK_AND_ASSIGN(auto write_decls,
                       DeserializePlans(*plan_buf, write_options_factory));
  ASSERT_EQ(write_decls.size(), 1);

  // Top-level declaration: the write node with exactly one input.
  compute::Declaration* write_decl = &write_decls[0];
  ASSERT_EQ(write_decl->factory_name, "write");
  ASSERT_EQ(write_decl->inputs.size(), 1);

  // Its input: the scan node.
  compute::Declaration* scan_decl =
      std::get_if<compute::Declaration>(&write_decl->inputs[0]);
  ASSERT_NE(scan_decl, nullptr);
  ASSERT_EQ(scan_decl->factory_name, "scan");

  ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make());
  ASSERT_OK_AND_ASSIGN(auto sink_node, write_decls[0].AddToPlan(plan.get()));
  ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode");
  ASSERT_EQ(sink_node->num_inputs(), 1);

  auto& source_node = sink_node->inputs()[0];
  ASSERT_STREQ(source_node->kind_name(), "SourceNode");

  ASSERT_OK(plan->StartProducing());
  ASSERT_FINISHES_OK(plan->finished());
}
// Runs `test` once for every combination of {no extension-id registry, nested
// extension-id registry} x {default function registry, nested function
// registry}. The extension-id registry varies in the outer loop so the
// invocation order is: (null, default), (null, nested), (nested, default),
// (nested, nested).
static void test_with_registries(
    std::function<void(ExtensionIdRegistry*, compute::FunctionRegistry*)> test) {
  auto default_func_reg = compute::GetFunctionRegistry();
  auto nested_ext_id_reg = MakeExtensionIdRegistry();
  auto nested_func_reg = compute::FunctionRegistry::Make(default_func_reg);

  ExtensionIdRegistry* const ext_id_registries[] = {nullptr, nested_ext_id_reg.get()};
  compute::FunctionRegistry* const func_registries[] = {default_func_reg,
                                                        nested_func_reg.get()};
  for (ExtensionIdRegistry* ext_id_reg : ext_id_registries) {
    for (compute::FunctionRegistry* func_reg : func_registries) {
      test(ext_id_reg, func_reg);
    }
  }
}
// Executes the shared test plan through ExecuteSerializedPlan and checks the
// number of rows read back, once per registry combination.
// NOTE(review): the registries handed to the callback are not forwarded to
// ExecuteSerializedPlan, so every iteration exercises the same code path;
// confirm whether that is intended.
TEST(Substrait, GetRecordBatchReader) {
  ASSERT_OK_AND_ASSIGN(std::string plan_json, GetSubstraitJSON());
  auto run_once = [&plan_json](ExtensionIdRegistry* /*ext_id_reg*/,
                               compute::FunctionRegistry* /*func_registry*/) {
    ASSERT_OK_AND_ASSIGN(auto serialized, SerializeJsonPlan(plan_json));
    ASSERT_OK_AND_ASSIGN(auto batch_reader, ExecuteSerializedPlan(*serialized));
    ASSERT_OK_AND_ASSIGN(auto result_table,
                         Table::FromRecordBatchReader(batch_reader.get()));
    // Note: this assumes the binary.parquet file contains a fixed number of
    // records; in case of a test failure, re-evaluate the content of the file.
    EXPECT_EQ(result_table->num_rows(), 12);
  };
  test_with_registries(run_once);
}
// A plan with an empty "relations" list must be rejected with Invalid when
// executed, once per registry combination.
// NOTE(review): the registries handed to the callback are not forwarded to
// ExecuteSerializedPlan; confirm whether that is intended.
TEST(Substrait, InvalidPlan) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [
]
})";
  auto run_once = [&plan_json](ExtensionIdRegistry* /*ext_id_reg*/,
                               compute::FunctionRegistry* /*func_registry*/) {
    ASSERT_OK_AND_ASSIGN(auto serialized, SerializeJsonPlan(plan_json));
    ASSERT_RAISES(Invalid, ExecuteSerializedPlan(*serialized));
  };
  test_with_registries(run_once);
}
// A plan declaring Substrait version 0.18.0 must be rejected with Invalid
// (presumably because it is below the minimum supported version -- the exact
// threshold is not visible from this test).
TEST(Substrait, InvalidMinimumVersion) {
  const std::string plan_json = R"({
"version": { "major_number": 0, "minor_number": 18, "patch_number": 0 },
"relations": [{
"rel": {
"read": {
"base_schema": {
"names": ["A"],
"struct": {
"types": [{
"i32": {}
}]
}
},
"named_table": { "names": ["x"] }
}
}
}],
"extensionUris": [],
"extensions": [],
})";
  ASSERT_OK_AND_ASSIGN(auto serialized, internal::SubstraitFromJSON("Plan", plan_json));
  ASSERT_RAISES(Invalid, DeserializePlans(*serialized, [] { return kNullConsumer; }));
}
// Deserializes an inner join of two three-column i32 reads whose join
// expression is `equal(field 0, field 5)` and checks that it becomes a
// "hashjoin" declaration with an EQ key comparison over two scans with the
// expected schemas. Runs with both a null and a nested extension-id registry.
TEST(Substrait, JoinPlanBasic) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"join": {
"left": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat1.parquet",
"parquet": {}
}
]
}
}
},
"right": {
"read": {
"base_schema": {
"names": ["X", "Y", "A"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat2.parquet",
"parquet": {}
}
]
}
}
},
"expression": {
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 5
}
},
"rootReference": {
}
}
}
}],
"output_type": {
"bool": {}
}
}
},
"type": "JOIN_TYPE_INNER"
}
}
}],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
                                R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  for (const auto& registry_holder :
       {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
    ExtensionIdRegistry* registry = registry_holder.get();
    ExtensionSet extension_set(registry);
    ASSERT_OK_AND_ASSIGN(
        auto sink_decls,
        DeserializePlans(
            *serialized, [] { return kNullConsumer; }, registry, &extension_set));

    // The sink's only input is the join declaration.
    const auto& join_rel = std::get<compute::Declaration>(sink_decls[0].inputs[0]);
    const auto& join_options =
        checked_cast<const compute::HashJoinNodeOptions&>(*join_rel.options);
    EXPECT_EQ(join_rel.factory_name, "hashjoin");
    EXPECT_EQ(join_options.join_type, compute::JoinType::INNER);

    // Both join inputs are scans; verify the schemas of their datasets.
    const auto& left_scan = std::get<compute::Declaration>(join_rel.inputs[0]);
    const auto& right_scan = std::get<compute::Declaration>(join_rel.inputs[1]);
    const auto& left_scan_opts =
        checked_cast<const dataset::ScanNodeOptions&>(*left_scan.options);
    const auto& right_scan_opts =
        checked_cast<const dataset::ScanNodeOptions&>(*right_scan.options);
    AssertSchemaEqual(
        left_scan_opts.dataset->schema(),
        schema({field("A", int32()), field("B", int32()), field("C", int32())}));
    AssertSchemaEqual(
        right_scan_opts.dataset->schema(),
        schema({field("X", int32()), field("Y", int32()), field("A", int32())}));

    EXPECT_EQ(join_options.key_cmp[0], compute::JoinKeyCmp::EQ);
  }
}
// Same join shape as JoinPlanBasic, but the join expression references the
// arithmetic function "add" rather than a comparison. Deserialization must
// fail with Invalid for both a null and a nested extension-id registry.
TEST(Substrait, JoinPlanInvalidKeyCmp) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"join": {
"left": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat1.parquet",
"parquet": {}
}
]
}
}
},
"right": {
"read": {
"base_schema": {
"names": ["X", "Y", "A"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat2.parquet",
"parquet": {}
}
]
}
}
},
"expression": {
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 5
}
},
"rootReference": {
}
}
}
}],
"output_type": {
"bool": {}
}
}
},
"type": "JOIN_TYPE_INNER"
}
}
}],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitArithmeticFunctionsUri) +
                                R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "add"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  for (const auto& registry_holder :
       {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
    ExtensionIdRegistry* registry = registry_holder.get();
    ExtensionSet extension_set(registry);
    ASSERT_RAISES(Invalid,
                  DeserializePlans(
                      *serialized, [] { return kNullConsumer; }, registry,
                      &extension_set));
  }
}
// A join whose "expression" is an empty list literal instead of a scalar
// function call must fail deserialization with Invalid, for both a null and a
// nested extension-id registry.
TEST(Substrait, JoinPlanInvalidExpression) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"join": {
"left": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat1.parquet",
"parquet": {}
}
]
}
}
},
"right": {
"read": {
"base_schema": {
"names": ["X", "Y", "A"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat2.parquet",
"parquet": {}
}
]
}
}
},
"expression": {"literal": {"list": {"values": []}}},
"type": "JOIN_TYPE_INNER"
}
}
}]
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  for (const auto& registry_holder :
       {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
    ExtensionIdRegistry* registry = registry_holder.get();
    ExtensionSet extension_set(registry);
    ASSERT_RAISES(Invalid,
                  DeserializePlans(
                      *serialized, [] { return kNullConsumer; }, registry,
                      &extension_set));
  }
}
// A join that declares only a "left" input (no "right") while its expression
// references fields 0 and 5 must fail deserialization with Invalid, for both a
// null and a nested extension-id registry.
TEST(Substrait, JoinPlanInvalidKeys) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"join": {
"left": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat1.parquet",
"parquet": {}
}
]
}
}
},
"expression": {
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 5
}
},
"rootReference": {
}
}
}
}]
}
},
"type": "JOIN_TYPE_INNER"
}
}
}]
})";
  ASSERT_OK_AND_ASSIGN(auto serialized, internal::SubstraitFromJSON("Plan", plan_json));

  for (const auto& registry_holder :
       {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
    ExtensionIdRegistry* registry = registry_holder.get();
    ExtensionSet extension_set(registry);
    ASSERT_RAISES(Invalid,
                  DeserializePlans(
                      *serialized, [] { return kNullConsumer; }, registry,
                      &extension_set));
  }
}
// Deserializes a grouped aggregate (group by field 0, `sum` over field 1) on a
// three-column i32 read and checks that it becomes an "aggregate" declaration
// whose first aggregate is the unnamed "hash_sum" function.
TEST(Substrait, AggregateBasic) {
  ASSERT_OK_AND_ASSIGN(auto buf,
                       internal::SubstraitFromJSON("Plan", R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
"input": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat.parquet",
"parquet": {}
}
]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
}
}
}]
}],
"measures": [{
"measure": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
}
}
}
}],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"outputType": {
"i64": {}
}
}
}]
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "sum"
}
}],
})",
                                                   /*ignore_unknown_fields=*/false));
  ASSERT_OK_AND_ASSIGN(auto sink_decls,
                       DeserializePlans(*buf, [] { return kNullConsumer; }));

  // Guard every index below so a malformed result fails the test cleanly
  // instead of reading out of bounds.
  ASSERT_FALSE(sink_decls.empty());
  ASSERT_FALSE(sink_decls[0].inputs.empty());
  const auto& agg_rel = std::get<compute::Declaration>(sink_decls[0].inputs[0]);
  const auto& agg_options =
      checked_cast<const compute::AggregateNodeOptions&>(*agg_rel.options);
  EXPECT_EQ(agg_rel.factory_name, "aggregate");
  ASSERT_FALSE(agg_options.aggregates.empty());
  EXPECT_EQ(agg_options.aggregates[0].name, "");
  EXPECT_EQ(agg_options.aggregates[0].function, "hash_sum");
}
// An aggregate relation with an empty body (no input, groupings or measures)
// must fail deserialization with Invalid.
TEST(Substrait, AggregateInvalidRel) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "sum"
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  ASSERT_RAISES(Invalid, DeserializePlans(*serialized, [] { return kNullConsumer; }));
}
// An aggregate whose "measures" list holds an empty object (no "measure"
// function) must fail deserialization with Invalid.
TEST(Substrait, AggregateInvalidFunction) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
"input": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat.parquet",
"parquet": {}
}
]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
}
}
}]
}],
"measures": [{
}]
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "sum"
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  ASSERT_RAISES(Invalid, DeserializePlans(*serialized, [] { return kNullConsumer; }));
}
// A `sum` measure declared with an empty argument list must fail
// deserialization with NotImplemented.
TEST(Substrait, AggregateInvalidAggFuncArgs) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
"input": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat.parquet",
"parquet": {}
}
]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
}
}
}]
}],
"measures": [{
"measure": {
"functionReference": 0,
"arguments": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"invocation": "AGGREGATION_INVOCATION_ALL",
"outputType": {
"i64": {}
}
}
}]
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "sum"
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  ASSERT_RAISES(NotImplemented,
                DeserializePlans(*serialized, [] { return kNullConsumer; }));
}
// An aggregate whose measure references the function "equal" (registered from
// the Arrow extension-types URI) with no arguments must fail deserialization
// with NotImplemented.
// NOTE(review): despite the test name, the measure in this plan carries no
// "filter" field -- confirm the plan exercises the intended code path.
TEST(Substrait, AggregateWithFilter) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
"input": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat.parquet",
"parquet": {}
}
]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
}
}
}]
}],
"measures": [{
"measure": {
"functionReference": 0,
"arguments": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"invocation": "AGGREGATION_INVOCATION_ALL",
"outputType": {
"i64": {}
}
}
}]
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  ASSERT_RAISES(NotImplemented,
                DeserializePlans(*serialized, [] { return kNullConsumer; }));
}
// An aggregate whose measure uses AGGREGATION_INVOCATION_DISTINCT (function
// "equal", no arguments) must fail deserialization with NotImplemented.
// NOTE(review): despite the test name, the "phase" here is the same
// AGGREGATION_PHASE_INITIAL_TO_RESULT used by the other aggregate tests; only
// "invocation" differs -- confirm which field is meant to trigger the error.
TEST(Substrait, AggregateBadPhase) {
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
"input": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_file": "file:///tmp/dat.parquet",
"parquet": {}
}
]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
}
}
}]
}],
"measures": [{
"measure": {
"functionReference": 0,
"arguments": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"invocation": "AGGREGATION_INVOCATION_DISTINCT",
"outputType": {
"i64": {}
}
}
}]
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  ASSERT_RAISES(NotImplemented,
                DeserializePlans(*serialized, [] { return kNullConsumer; }));
}
// Round-trips a scan -> filter -> sink declaration sequence through
// SerializePlan / DeserializePlans and checks that the filter expression, the
// dataset schema and the dataset fragments survive unchanged.
TEST(SubstraitRoundTrip, BasicPlan) {
  arrow::dataset::internal::Initialize();
  auto dummy_schema = schema(
      {field("key", int32()), field("shared", int32()), field("distinct", int32())});
  // Build a small table and write it to an IPC file that backs the dataset.
  auto table = TableFromJSON(dummy_schema, {R"([
[1, 1, 10],
[3, 4, 20]
])",
                                            R"([
[0, 2, 1],
[1, 3, 2],
[4, 1, 3],
[3, 1, 3],
[1, 2, 5]
])",
                                            R"([
[2, 2, 12],
[5, 3, 12],
[1, 3, 12]
])"});
  auto format = std::make_shared<arrow::dataset::IpcFileFormat>();
  auto filesystem = std::make_shared<fs::LocalFileSystem>();
  const std::string file_name = "serde_test.arrow";
  ASSERT_OK_AND_ASSIGN(auto tempdir,
                       arrow::internal::TemporaryDir::Make("substrait-tempdir-"));
  ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
  std::string file_path_str = file_path.ToString();
  WriteIpcData(file_path_str, filesystem, table);
  std::vector<fs::FileInfo> files;
  const std::vector<std::string> f_paths = {file_path_str};
  for (const auto& f_path : f_paths) {
    ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
    files.push_back(std::move(f_file));
  }
  ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make(
                                            filesystem, std::move(files), format, {}));
  ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
  auto scan_options = std::make_shared<dataset::ScanOptions>();
  scan_options->projection = compute::project({}, {});
  const std::string filter_col_left = "shared";
  const std::string filter_col_right = "distinct";
  auto comp_left_value = compute::field_ref(filter_col_left);
  auto comp_right_value = compute::field_ref(filter_col_right);
  auto filter = compute::equal(comp_left_value, comp_right_value);
  arrow::AsyncGenerator<std::optional<compute::ExecBatch>> sink_gen;
  auto declarations = compute::Declaration::Sequence(
      {compute::Declaration(
           {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}),
       compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}),
       compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})});
  std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg = MakeExtensionIdRegistry();
  ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
  ExtensionSet ext_set(ext_id_reg);
  ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set));
  ASSERT_OK_AND_ASSIGN(
      auto sink_decls,
      DeserializePlans(
          *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set));
  // filter declaration
  const auto& roundtripped_filter =
      std::get<compute::Declaration>(sink_decls[0].inputs[0]);
  const auto& filter_opts =
      checked_cast<const compute::FilterNodeOptions&>(*(roundtripped_filter.options));
  const auto& roundtripped_expr = filter_opts.filter_expression;
  // The filter must come back as a call; otherwise the checks below would be
  // skipped silently and the test would pass without verifying anything.
  const auto* call = roundtripped_expr.call();
  ASSERT_NE(call, nullptr);
  EXPECT_EQ(call->function_name, "equal");
  const auto& args = call->arguments;
  auto left_index = args[0].field_ref()->field_path()->indices()[0];
  EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left);
  auto right_index = args[1].field_ref()->field_path()->indices()[0];
  EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right);
  // scan declaration
  const auto& roundtripped_scan =
      std::get<compute::Declaration>(roundtripped_filter.inputs[0]);
  const auto& dataset_opts =
      checked_cast<const dataset::ScanNodeOptions&>(*(roundtripped_scan.options));
  const auto& roundripped_ds = dataset_opts.dataset;
  EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema));
  ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments());
  ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments());
  auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs));
  auto expected_frg_vec = IteratorToVector(std::move(expected_frgs));
  // Fatal assertion: the loop below indexes both vectors in lockstep, so a
  // size mismatch must stop the test before any out-of-bounds access.
  ASSERT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size());
  for (size_t idx = 0; idx < expected_frg_vec.size(); ++idx) {
    const auto* l_frag =
        checked_cast<const dataset::FileFragment*>(expected_frg_vec[idx].get());
    const auto* r_frag =
        checked_cast<const dataset::FileFragment*>(roundtrip_frg_vec[idx].get());
    EXPECT_TRUE(l_frag->Equals(*r_frag));
  }
}
// Round-trips a scan -> filter declaration sequence through SerializePlan /
// DeserializePlans, checks that the filter expression, dataset schema and
// fragments survive, and finally executes both the original and the
// round-tripped declarations and compares their output tables.
TEST(SubstraitRoundTrip, BasicPlanEndToEnd) {
  arrow::dataset::internal::Initialize();
  auto dummy_schema = schema(
      {field("key", int32()), field("shared", int32()), field("distinct", int32())});
  // Build a small table and write it to an IPC file that backs the dataset.
  auto table = TableFromJSON(dummy_schema, {R"([
[1, 1, 10],
[3, 4, 4]
])",
                                            R"([
[0, 2, 1],
[1, 3, 2],
[4, 1, 1],
[3, 1, 3],
[1, 2, 2]
])",
                                            R"([
[2, 2, 12],
[5, 3, 12],
[1, 3, 3]
])"});
  auto format = std::make_shared<arrow::dataset::IpcFileFormat>();
  auto filesystem = std::make_shared<fs::LocalFileSystem>();
  const std::string file_name = "serde_test.arrow";
  ASSERT_OK_AND_ASSIGN(auto tempdir,
                       arrow::internal::TemporaryDir::Make("substrait-tempdir-"));
  ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
  std::string file_path_str = file_path.ToString();
  WriteIpcData(file_path_str, filesystem, table);
  std::vector<fs::FileInfo> files;
  const std::vector<std::string> f_paths = {file_path_str};
  for (const auto& f_path : f_paths) {
    ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
    files.push_back(std::move(f_file));
  }
  ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make(
                                            filesystem, std::move(files), format, {}));
  ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
  auto scan_options = std::make_shared<dataset::ScanOptions>();
  scan_options->projection = compute::project({}, {});
  const std::string filter_col_left = "shared";
  const std::string filter_col_right = "distinct";
  auto comp_left_value = compute::field_ref(filter_col_left);
  auto comp_right_value = compute::field_ref(filter_col_right);
  auto filter = compute::equal(comp_left_value, comp_right_value);
  auto declarations = compute::Declaration::Sequence(
      {compute::Declaration(
           {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}),
       compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"})});
  ASSERT_OK_AND_ASSIGN(auto expected_table, compute::DeclarationToTable(declarations));
  std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg = MakeExtensionIdRegistry();
  ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
  ExtensionSet ext_set(ext_id_reg);
  ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set));
  ASSERT_OK_AND_ASSIGN(
      auto sink_decls,
      DeserializePlans(
          *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set));
  // filter declaration
  auto& roundtripped_filter = std::get<compute::Declaration>(sink_decls[0].inputs[0]);
  const auto& filter_opts =
      checked_cast<const compute::FilterNodeOptions&>(*(roundtripped_filter.options));
  const auto& roundtripped_expr = filter_opts.filter_expression;
  // The filter must come back as a call; otherwise the checks below would be
  // skipped silently and the test would pass without verifying anything.
  const auto* call = roundtripped_expr.call();
  ASSERT_NE(call, nullptr);
  EXPECT_EQ(call->function_name, "equal");
  const auto& args = call->arguments;
  auto left_index = args[0].field_ref()->field_path()->indices()[0];
  EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left);
  auto right_index = args[1].field_ref()->field_path()->indices()[0];
  EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right);
  // scan declaration
  const auto& roundtripped_scan =
      std::get<compute::Declaration>(roundtripped_filter.inputs[0]);
  const auto& dataset_opts =
      checked_cast<const dataset::ScanNodeOptions&>(*(roundtripped_scan.options));
  const auto& roundripped_ds = dataset_opts.dataset;
  EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema));
  ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments());
  ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments());
  auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs));
  auto expected_frg_vec = IteratorToVector(std::move(expected_frgs));
  // Fatal assertion: the loop below indexes both vectors in lockstep, so a
  // size mismatch must stop the test before any out-of-bounds access.
  ASSERT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size());
  for (size_t idx = 0; idx < expected_frg_vec.size(); ++idx) {
    const auto* l_frag =
        checked_cast<const dataset::FileFragment*>(expected_frg_vec[idx].get());
    const auto* r_frag =
        checked_cast<const dataset::FileFragment*>(roundtrip_frg_vec[idx].get());
    EXPECT_TRUE(l_frag->Equals(*r_frag));
  }
  // Execute the round-tripped filter (and its scan input) and compare with the
  // table produced by the original declarations.
  ASSERT_OK_AND_ASSIGN(auto rnd_trp_table,
                       compute::DeclarationToTable(roundtripped_filter));
  compute::AssertTablesEqualIgnoringOrder(expected_table, rnd_trp_table);
}
// Serializes a named_table -> filter(A == B) declaration sequence, then
// executes the serialized plan with a named-table provider that substitutes an
// in-memory table, and checks that only rows with A == B are produced.
TEST(SubstraitRoundTrip, FilterNamedTable) {
  arrow::dataset::internal::Initialize();

  const std::vector<std::string> table_names{"table", "1"};
  const auto dummy_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});

  auto a_equals_b = compute::equal(compute::field_ref("A"), compute::field_ref("B"));
  auto declarations = compute::Declaration::Sequence(
      {compute::Declaration({"named_table",
                             compute::NamedTableNodeOptions{table_names, dummy_schema},
                             "n"}),
       compute::Declaration({"filter", compute::FilterNodeOptions{a_equals_b}, "f"})});

  ExtensionSet ext_set{};
  ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set));

  // Table handed out by the provider in place of the named table.
  auto input_table = TableFromJSON(dummy_schema, {R"([
[1, 1, 10],
[3, 5, 20],
[4, 1, 30],
[2, 1, 40],
[5, 5, 50],
[2, 2, 60]
])"});

  // Resolves only the exact name list used above; anything else is an error.
  NamedTableProvider table_provider =
      [&input_table, &table_names](
          const std::vector<std::string>& names) -> Result<compute::Declaration> {
    if (names == table_names) {
      std::shared_ptr<compute::ExecNodeOptions> source_options =
          std::make_shared<compute::TableSourceNodeOptions>(input_table);
      return compute::Declaration("table_source", {}, std::move(source_options),
                                  "mock_source");
    }
    return Status::Invalid("Table name mismatch");
  };
  ConversionOptions conversion_options;
  conversion_options.named_table_provider = std::move(table_provider);

  auto expected_table = TableFromJSON(dummy_schema, {R"([
[1, 1, 10],
[5, 5, 50],
[2, 2, 60]
])"});
  CheckRoundTripResult(std::move(expected_table), serialized_plan,
                       /*include_columns=*/{}, conversion_options);
}
// Executes a project relation that appends `equal(A, B)` to a three-column
// named-table read and checks the output: the three input columns followed by
// the boolean comparison result.
TEST(SubstraitRoundTrip, ProjectRel) {
  auto dummy_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  // Table handed out by the named-table provider.
  auto input_table = TableFromJSON(dummy_schema, {R"([
[1, 1, 10],
[3, 5, 20],
[4, 1, 30],
[2, 1, 40],
[5, 5, 50],
[2, 2, 60]
])"});
  std::string substrait_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"project": {
"expressions": [{
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {
}
}
}
}],
"output_type": {
"bool": {}
}
}
},
],
"input" : {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable": {
"names": ["A"]
}
}
}
}
}
}],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
                               R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto buf,
                       internal::SubstraitFromJSON("Plan", substrait_json,
                                                   /*ignore_unknown_fields=*/false));
  auto output_schema = schema({field("A", int32()), field("B", int32()),
                               field("C", int32()), field("equal", boolean())});
  auto expected_table = TableFromJSON(output_schema, {R"([
[1, 1, 10, true],
[3, 5, 20, false],
[4, 1, 30, false],
[2, 1, 40, false],
[5, 5, 50, true],
[2, 2, 60, true]
])"});
  NamedTableProvider table_provider = AlwaysProvideSameTable(std::move(input_table));
  ConversionOptions conversion_options;
  conversion_options.named_table_provider = std::move(table_provider);
  CheckRoundTripResult(std::move(expected_table), buf,
                       /*include_columns=*/{}, conversion_options);
}
// Executes a project relation that appends `equal(A, B)` and emits output
// columns [0, 2, 3], i.e. A, C and the comparison result. Also checks that the
// resulting plan contains exactly one project node.
TEST(SubstraitRoundTrip, ProjectRelOnFunctionWithEmit) {
  auto input_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  // Table handed out by the named-table provider.
  auto source_table = TableFromJSON(input_schema, {R"([
[1, 1, 10],
[3, 5, 20],
[4, 1, 30],
[2, 1, 40],
[5, 5, 50],
[2, 2, 60]
])"});
  const std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"project": {
"common": {
"emit": {
"outputMapping": [0, 2, 3]
}
},
"expressions": [{
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {
}
}
}
}],
"output_type": {
"bool": {}
}
}
},
],
"input" : {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable": {
"names": ["A"]
}
}
}
}
}
}],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
                                R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto serialized,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  auto emitted_schema =
      schema({field("A", int32()), field("C", int32()), field("equal", boolean())});
  auto expected_table = TableFromJSON(emitted_schema, {R"([
[1, 10, true],
[3, 20, false],
[4, 30, false],
[2, 40, false],
[5, 50, true],
[2, 60, true]
])"});

  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  ValidateNumProjectNodes(1, serialized, conversion_options);
  CheckRoundTripResult(std::move(expected_table), serialized,
                       /*include_columns=*/{}, conversion_options);
}
// Executes two nested project relations, each appending `equal(A, B)` and each
// emitting all of its columns, on a two-column named-table read. Checks that
// the plan contains two project nodes and that the output holds A, B and both
// comparison results.
TEST(SubstraitRoundTrip, ProjectRelOnFunctionWithAllEmit) {
  auto dummy_schema = schema({field("A", int32()), field("B", int32())});
  // Table handed out by the named-table provider.
  auto input_table = TableFromJSON(dummy_schema, {R"([
[1, 1],
[3, 5],
[4, 1],
[2, 1],
[5, 5],
[2, 2]
])"});
  std::string substrait_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations":[
{
"rel":{
"project":{
"common":{
"emit":{
"outputMapping":[
0,
1,
2,
3
]
}
},
"expressions":[
{
"scalarFunction":{
"functionReference":0,
"arguments":[
{
"value":{
"selection":{
"directReference":{
"structField":{
"field":0
}
},
"rootReference":{
}
}
}
},
{
"value":{
"selection":{
"directReference":{
"structField":{
"field":1
}
},
"rootReference":{
}
}
}
}
],
"output_type":{
"bool":{
}
}
}
}
],
"input":{
"project":{
"common":{
"emit":{
"outputMapping":[
0,
1,
2
]
}
},
"expressions":[
{
"scalarFunction":{
"functionReference":0,
"arguments":[
{
"value":{
"selection":{
"directReference":{
"structField":{
"field":0
}
},
"rootReference":{
}
}
}
},
{
"value":{
"selection":{
"directReference":{
"structField":{
"field":1
}
},
"rootReference":{
}
}
}
}
],
"output_type":{
"bool":{
}
}
}
}
],
"input":{
"read":{
"base_schema":{
"names":[
"A",
"B"
],
"struct":{
"types":[
{
"i32":{
}
},
{
"i32":{
}
}
]
}
},
"namedTable":{
"names":[
"TABLE"
]
}
}
}
}
}
}
}
}
],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
                               R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto buf,
                       internal::SubstraitFromJSON("Plan", substrait_json,
                                                   /*ignore_unknown_fields=*/false));
  auto output_schema = schema({field("A", int32()), field("B", int32()),
                               field("eq1", boolean()), field("eq2", boolean())});
  auto expected_table = TableFromJSON(output_schema, {R"([
[1, 1, true, true],
[3, 5, false, false],
[4, 1, false, false],
[2, 1, false, false],
[5, 5, true, true],
[2, 2, true, true]
])"});
  NamedTableProvider table_provider = AlwaysProvideSameTable(std::move(input_table));
  ConversionOptions conversion_options;
  conversion_options.named_table_provider = std::move(table_provider);
  ValidateNumProjectNodes(2, buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), buf,
                       /*include_columns=*/{}, conversion_options);
}
// A ReadRel carrying an emit of [1, 2]: of the base schema columns A, B, C only
// B and C are expected in the result.
TEST(SubstraitRoundTrip, ReadRelWithEmit) {
  // In-memory table handed out for every named-table lookup.
  auto source_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  auto source_table = TableFromJSON(source_schema, {R"([
[1, 1, 10],
[3, 4, 20]
])"});
  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"read": {
"common": {
"emit": {
"outputMapping": [1, 2]
}
},
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable": {
"names" : ["A"]
}
}
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  // Columns B and C of the source rows.
  auto projected_schema = schema({field("B", int32()), field("C", int32())});
  auto expected_table = TableFromJSON(projected_schema, {R"([
[1, 10],
[4, 20]
])"});

  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// A FilterRel with condition equal(field 0, field 2) (A == C) and an emit of
// [1, 3]: the surviving rows are reduced to columns B and D.
TEST(SubstraitRoundTrip, FilterRelWithEmit) {
  // In-memory table handed out for every named-table lookup.
  auto source_schema = schema({field("A", int32()), field("B", int32()),
                               field("C", int32()), field("D", int32())});
  auto source_table = TableFromJSON(source_schema, {R"([
[10, 1, 80, 7],
[20, 2, 70, 6],
[30, 3, 30, 5],
[40, 4, 20, 4],
[40, 5, 40, 3],
[20, 6, 20, 2],
[30, 7, 30, 1]
])"});
  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"filter": {
"common": {
"emit": {
"outputMapping": [1, 3]
}
},
"condition": {
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 2
}
},
"rootReference": {
}
}
}
}],
"output_type": {
"bool": {}
}
}
},
"input" : {
"read": {
"base_schema": {
"names": ["A", "B", "C", "D"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
},{
"i32": {}
}]
}
},
"namedTable": {
"names" : ["A"]
}
}
}
}
}
}],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  // B and D of the four rows in which A equals C.
  auto filtered_schema = schema({field("B", int32()), field("D", int32())});
  auto expected_table = TableFromJSON(filtered_schema, {R"([
[3, 5],
[5, 3],
[6, 2],
[7, 1]
])"});

  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// An inner JoinRel of two named tables on equal(field 0, field 2), i.e. the first
// column of "left" against the first column of "right". Only the row with the
// shared value 10 is expected, carrying all four columns.
TEST(SubstraitRoundTrip, JoinRel) {
  // The two join inputs, served under the table names "left" and "right".
  auto lhs_schema = schema({field("A", int32()), field("B", int32())});
  auto rhs_schema = schema({field("X", int32()), field("Y", int32())});
  auto left_table = TableFromJSON(lhs_schema, {R"([
[10, 1],
[20, 2],
[30, 3]
])"});
  auto right_table = TableFromJSON(rhs_schema, {R"([
[10, 11],
[80, 21],
[31, 31]
])"});
  std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"join": {
"left": {
"read": {
"base_schema": {
"names": ["A", "B"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable": {
"names" : ["left"]
}
}
},
"right": {
"read": {
"base_schema": {
"names": ["X", "Y"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable": {
"names" : ["right"]
}
}
},
"expression": {
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 2
}
},
"rootReference": {
}
}
}
}],
"output_type": {
"bool": {}
}
}
},
"type": "JOIN_TYPE_INNER"
}
}
}],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  // Left columns followed by right columns.
  auto joined_schema = schema({
      field("A", int32()),
      field("B", int32()),
      field("X", int32()),
      field("Y", int32()),
  });
  auto expected_table = TableFromJSON(std::move(joined_schema), {R"([
[10, 1, 10, 11]
])"});
  ConversionOptions conversion_options;
  // Picks the table by name; when several names are given the last recognised one
  // wins, and an unrecognised name leaves the table pointer null.
  conversion_options.named_table_provider =
      [left_table, right_table](const std::vector<std::string>& names) {
        std::shared_ptr<Table> selected;
        for (const auto& name : names) {
          if (name == "left") {
            selected = left_table;
          } else if (name == "right") {
            selected = right_table;
          }
        }
        std::shared_ptr<compute::ExecNodeOptions> options =
            std::make_shared<compute::TableSourceNodeOptions>(std::move(selected));
        return compute::Declaration("table_source", {}, options, "mock_source");
      };
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// Same inner join as the plain JoinRel test, but with an emit of [0, 1, 3]: the
// joined row is reduced to A, B and Y.
TEST(SubstraitRoundTrip, JoinRelWithEmit) {
  // The two join inputs, served under the table names "left" and "right".
  auto lhs_schema = schema({field("A", int32()), field("B", int32())});
  auto rhs_schema = schema({field("X", int32()), field("Y", int32())});
  auto left_table = TableFromJSON(lhs_schema, {R"([
[10, 1],
[20, 2],
[30, 3]
])"});
  auto right_table = TableFromJSON(rhs_schema, {R"([
[10, 11],
[80, 21],
[31, 31]
])"});
  std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"join": {
"common": {
"emit": {
"outputMapping": [0, 1, 3]
}
},
"left": {
"read": {
"base_schema": {
"names": ["A", "B"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable" : {
"names" : ["left"]
}
}
},
"right": {
"read": {
"base_schema": {
"names": ["X", "Y"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable" : {
"names" : ["right"]
}
}
},
"expression": {
"scalarFunction": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"selection": {
"directReference": {
"structField": {
"field": 2
}
},
"rootReference": {
}
}
}
}],
"output_type": {
"bool": {}
}
}
},
"type": "JOIN_TYPE_INNER"
}
}
}],
"extension_uris": [
{
"extension_uri_anchor": 0,
"uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
R"("
}
],
"extensions": [
{"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "equal"
}}
]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  // Joined columns 0, 1 and 3.
  auto emitted_schema = schema({
      field("A", int32()),
      field("B", int32()),
      field("Y", int32()),
  });
  auto expected_table = TableFromJSON(std::move(emitted_schema), {R"([
[10, 1, 11]
])"});
  ConversionOptions conversion_options;
  // Picks the table by name; when several names are given the last recognised one
  // wins, and an unrecognised name leaves the table pointer null.
  conversion_options.named_table_provider =
      [left_table, right_table](const std::vector<std::string>& names) {
        std::shared_ptr<Table> selected;
        for (const auto& name : names) {
          if (name == "left") {
            selected = left_table;
          } else if (name == "right") {
            selected = right_table;
          }
        }
        std::shared_ptr<compute::ExecNodeOptions> options =
            std::make_shared<compute::TableSourceNodeOptions>(std::move(selected));
        return compute::Declaration("table_source", {}, options, "mock_source");
      };
  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// An AggregateRel that groups on field 0 (A) and applies the "sum" extension
// function to field 2 (C). The expected table lists the sum column first and the
// grouping key second.
TEST(SubstraitRoundTrip, AggregateRel) {
  // In-memory table handed out for every named-table lookup.
  auto source_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  auto source_table = TableFromJSON(source_schema, {R"([
[10, 1, 80],
[20, 2, 70],
[30, 3, 30],
[40, 4, 20],
[40, 5, 40],
[20, 6, 20],
[30, 7, 30]
])"});
  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
"input": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable" : {
"names": ["A"]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
}
}
}]
}],
"measures": [{
"measure": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 2
}
}
}
}
}],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"invocation": "AGGREGATION_INVOCATION_ALL",
"outputType": {
"i64": {}
}
}
}]
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "sum"
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  // Per-key sums of C: 10 -> 80, 20 -> 70 + 20, 30 -> 30 + 30, 40 -> 20 + 40.
  auto aggregated_schema =
      schema({field("aggregates", int64()), field("keys", int32())});
  auto expected_table = TableFromJSON(aggregated_schema, {R"([
[80, 10],
[90, 20],
[60, 30],
[60, 40]
])"});

  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// The same grouped sum as the plain AggregateRel test, but with an emit of [0]:
// only the sum column is expected in the result.
TEST(SubstraitRoundTrip, AggregateRelEmit) {
  // In-memory table handed out for every named-table lookup.
  auto source_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  auto source_table = TableFromJSON(source_schema, {R"([
[10, 1, 80],
[20, 2, 70],
[30, 3, 30],
[40, 4, 20],
[40, 5, 40],
[20, 6, 20],
[30, 7, 30]
])"});
  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  // TODO: fixme https://issues.apache.org/jira/browse/ARROW-17484
  std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"relations": [{
"rel": {
"aggregate": {
"common": {
"emit": {
"outputMapping": [0]
}
},
"input": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"namedTable" : {
"names" : ["A"]
}
}
},
"groupings": [{
"groupingExpressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
}
}
}]
}],
"measures": [{
"measure": {
"functionReference": 0,
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 2
}
}
}
}
}],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"invocation": "AGGREGATION_INVOCATION_ALL",
"outputType": {
"i64": {}
}
}
}]
}
}
}],
"extensionUris": [{
"extension_uri_anchor": 0,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "sum"
}
}],
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));

  // Only the per-key sums of C remain.
  auto emitted_schema = schema({field("aggregates", int64())});
  auto expected_table = TableFromJSON(emitted_schema, {R"([
[80],
[90],
[60],
[60]
])"});

  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// Runs a plan produced by Isthmus with the command
//   isthmus -c "CREATE TABLE T1(foo int)" "SELECT foo + 1 FROM T1"
// The project emits only index 1, the "add:i32_i32" expression, so every input
// value is expected back incremented by one.
TEST(Substrait, IsthmusPlan) {
  std::string plan_json = R"({
"version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 },
"extensionUris": [{
"extensionUriAnchor": 1,
"uri": "/functions_arithmetic.yaml"
}],
"extensions": [{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 0,
"name": "add:i32_i32"
}
}],
"relations": [{
"root": {
"input": {
"project": {
"common": {
"emit": {
"outputMapping": [1]
}
},
"input": {
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": ["FOO"],
"struct": {
"types": [{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
}],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["T1"]
}
}
},
"expressions": [{
"scalarFunction": {
"functionReference": 0,
"outputType": {
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"literal": {
"i32": 1,
"nullable": false,
"typeVariationReference": 0
}
}
}]
}
}]
}
},
"names": ["EXPR$0"]
}
}],
"expectedTypeUrls": []
})";
  // Input and expected output share the same single-column schema.
  auto foo_schema = schema({field("foo", int32())});
  auto source_table = TableFromJSON(foo_schema, {"[[1], [2], [5]]"});
  auto expected_table = TableFromJSON(foo_schema, {"[[2], [3], [6]]"});

  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  ASSERT_OK_AND_ASSIGN(auto plan_buf,
                       internal::SubstraitFromJSON("Plan", plan_json,
                                                   /*ignore_unknown_fields=*/false));
  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// Builds a NamedTableProvider on top of a table factory.
//
// The returned provider forwards the requested table names to `make`. If `make`
// fails, its error is returned to the caller; otherwise the produced table is
// wrapped in a "table_source" declaration labelled "mock_source".
NamedTableProvider ProvideMadeTable(
    std::function<Result<std::shared_ptr<Table>>(const std::vector<std::string>&)> make) {
  // `make` is a by-value sink parameter and `table` is a local that is not used
  // again, so both are moved instead of copied.
  return [make = std::move(make)](
             const std::vector<std::string>& names) -> Result<compute::Declaration> {
    ARROW_ASSIGN_OR_RAISE(auto table, make(names));
    std::shared_ptr<compute::ExecNodeOptions> options =
        std::make_shared<compute::TableSourceNodeOptions>(std::move(table));
    return compute::Declaration("table_source", {}, options, "mock_source");
  };
}
// A ProjectRel with four expressions (three plain field selections and
// "add:i32_i32" of field 0 and the literal 1) on top of a three-column read.
// The emit [0, 3, 6] selects input column A, the first expression (again A) and
// the last expression (A + 1).
TEST(Substrait, ProjectWithMultiFieldExpressions) {
  // In-memory table handed out for every named-table lookup.
  auto source_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  auto source_table = TableFromJSON(source_schema, {R"([
[10, 1, 80],
[20, 2, 70],
[30, 3, 30]
])"});
  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  const std::string plan_json = R"({
"extensionUris": [{
"extensionUriAnchor": 1,
"uri": "/functions_arithmetic.yaml"
}],
"extensions": [{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 0,
"name": "add:i32_i32"
}
}],
"relations": [{
"root": {
"input": {
"project": {
"common": {
"emit": {
"outputMapping": [0, 3, 6]
}
},
"input": {
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
}, {
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
}, {
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
}],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["SIMPLEDATA"]
}
}
},
"expressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}, {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {
}
}
}, {
"selection": {
"directReference": {
"structField": {
"field": 2
}
},
"rootReference": {
}
}
},{
"scalarFunction": {
"functionReference": 0,
"outputType": {
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
"arguments": [{
"value": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}
}, {
"value": {
"literal": {
"i32": 1,
"nullable": false,
"typeVariationReference": 0
}
}
}]
}
}]
}
},
"names": ["A", "B", "C", "D"]
}
}]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf, internal::SubstraitFromJSON("Plan", plan_json));

  // A, A again, and A + 1.
  auto emitted_schema =
      schema({field("A", int32()), field("A1", int32()), field("A+1", int32())});
  auto expected_table = TableFromJSON(emitted_schema, {R"([
[10, 10, 11],
[20, 20, 21],
[30, 30, 31]
])"});

  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// Two nested ProjectRels where only the inner one carries an emit.
//
// The inner project emits [2], the "add" of field 0 and the fp64 literal 10. The
// outer project has no emit and selects field 0 of that result, so the expected
// table holds the sum twice, as float64.
TEST(Substrait, NestedProjectWithMultiFieldExpressions) {
  // In-memory table handed out for every named-table lookup.
  auto source_schema = schema({field("A", int32())});
  auto source_table = TableFromJSON(source_schema, {R"([
[10],
[20],
[30]
])"});
  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  const std::string plan_json = R"({
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 2,
"name": "add"
}
}
],
"relations": [
{
"rel": {
"project": {
"input": {
"project": {
"common": {"emit": {"outputMapping": [2]}},
"input": {
"read": {
"baseSchema": {
"names": ["int"],
"struct": {"types": [{"i32": {}}]}
},
"namedTable": {
"names": ["SIMPLEDATA"]
}
}
},
"expressions": [
{"selection": {"directReference": {"structField": {"field": 0}}}},
{
"scalarFunction": {
"functionReference": 2,
"outputType": {"i32": {}},
"arguments": [
{"value": {"selection": {"directReference": {"structField": {"field": 0}}}}},
{"value": {"literal": {"fp64": 10}}}
]
}
}
]
}
},
"expressions": [
{"selection": {"directReference": {"structField": {"field": 0}}}}
]
}
}
}
]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf, internal::SubstraitFromJSON("Plan", plan_json));

  // Both columns carry A + 10.
  auto result_schema = schema({field("A", float64()), field("B", float64())});
  auto expected_table = TableFromJSON(result_schema, {R"([
[20, 20],
[30, 30],
[40, 40]
])"});

  ValidateNumProjectNodes(2, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// Two nested ProjectRels that both carry an emit.
//
// The inner project emits [1, 2]: the selection of field 0 and the "add" of
// field 0 with the fp64 literal 10. The outer project emits [2], its own
// selection of field 0, so the expected table is the original int32 column.
TEST(Substrait, NestedEmitProjectWithMultiFieldExpressions) {
  // In-memory table handed out for every named-table lookup.
  auto source_schema = schema({field("A", int32())});
  auto source_table = TableFromJSON(source_schema, {R"([
[10],
[20],
[30]
])"});
  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  const std::string plan_json = R"({
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 2,
"name": "add"
}
}
],
"relations": [
{
"rel": {
"project": {
"common": {"emit": {"outputMapping": [2]}},
"input": {
"project": {
"common": {"emit": {"outputMapping": [1, 2]}},
"input": {
"read": {
"baseSchema": {
"names": ["int"],
"struct": {"types": [{"i32": {}}]}
},
"namedTable": {
"names": ["SIMPLEDATA"]
}
}
},
"expressions": [
{"selection": {"directReference": {"structField": {"field": 0}}}},
{
"scalarFunction": {
"functionReference": 2,
"outputType": {"i32": {}},
"arguments": [
{"value": {"selection": {"directReference": {"structField": {"field": 0}}}}},
{"value": {"literal": {"fp64": 10}}}
]
}
}
]
}
},
"expressions": [
{"selection": {"directReference": {"structField": {"field": 0}}}}
]
}
}
}
]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf, internal::SubstraitFromJSON("Plan", plan_json));

  // The untouched input column.
  auto result_schema = schema({field("A", int32())});
  auto expected_table = TableFromJSON(result_schema, {R"([
[10],
[20],
[30]
])"});

  ValidateNumProjectNodes(2, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// Reads three Arrow IPC files through a single "uri_path_glob" entry of a
// ReadRel and checks that the rows of all three files come back.
//
// Each file is written into its own temporary directory created with the prefix
// "substrait-globfiles-"; the glob pattern is built from that shared prefix.
TEST(Substrait, ReadRelWithGlobFiles) {
#ifdef _WIN32
  GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
#endif
  arrow::dataset::internal::Initialize();
  auto dummy_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  // creating a dummy dataset using a dummy table
  auto table_1 = TableFromJSON(dummy_schema, {R"([
[1, 1, 10],
[3, 4, 20]
])"});
  auto table_2 = TableFromJSON(dummy_schema, {R"([
[11, 11, 110],
[13, 14, 120]
])"});
  auto table_3 = TableFromJSON(dummy_schema, {R"([
[21, 21, 210],
[23, 24, 220]
])"});
  auto expected_table = TableFromJSON(dummy_schema, {R"([
[1, 1, 10],
[3, 4, 20],
[11, 11, 110],
[13, 14, 120],
[21, 21, 210],
[23, 24, 220]
])"});
  std::vector<std::shared_ptr<Table>> input_tables = {table_1, table_2, table_3};
  auto filesystem = std::make_shared<fs::LocalFileSystem>();
  const std::vector<std::string> file_names = {"serde_test_1.arrow", "serde_test_2.arrow",
                                               "serde_test_3.arrow"};
  const std::string path_prefix = "substrait-globfiles-";
  // The temporary directories are kept in a vector so they stay in scope until
  // the end of the test; once one goes out of scope its folder is wiped out.
  std::vector<std::unique_ptr<arrow::internal::TemporaryDir>> tempdirs;
  for (size_t i = 0; i < file_names.size(); i++) {
    ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make(path_prefix));
    tempdirs.push_back(std::move(tempdir));
  }
  // Everything in front of the prefix in the first directory's path is taken as
  // the base directory for the glob pattern.
  std::string sample_tempdir_path = tempdirs[0]->path().ToString();
  std::string base_tempdir_path =
      sample_tempdir_path.substr(0, sample_tempdir_path.find(path_prefix));
  std::string glob_like_path =
      "file://" + base_tempdir_path + path_prefix + "*/serde_test_*.arrow";
  // File i holds table i and is written into temporary directory i.
  for (size_t i = 0; i < file_names.size(); i++) {
    ASSERT_OK_AND_ASSIGN(auto file_path, tempdirs[i]->path().Join(file_names[i]));
    std::string file_path_str = file_path.ToString();
    WriteIpcData(file_path_str, filesystem, input_tables[i]);
  }
  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
"relations": [{
"rel": {
"read": {
"base_schema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {}
}, {
"i32": {}
}, {
"i32": {}
}]
}
},
"local_files": {
"items": [
{
"uri_path_glob": ")" + glob_like_path +
R"(",
"arrow": {}
}
]
}
}
}
}]
})"));
  // To avoid unnecessary metadata columns being included in the final result
  std::vector<int> include_columns = {0, 1, 2};
  // The sort options on column "A" are handed to CheckRoundTripResult.
  compute::SortOptions options({compute::SortKey("A", compute::SortOrder::Ascending)});
  CheckRoundTripResult(std::move(expected_table), buf, std::move(include_columns),
                       /*conversion_options=*/{}, &options);
}
// A root relation that names its output columns "X", "Y", "Z".
//
// The project re-selects all three input fields and emits only those
// expressions ([3, 4, 5]), so the expected data equals the input data under the
// new column names.
TEST(Substrait, RootRelationOutputNames) {
  // The same rows serve as input and as expected output.
  const std::vector<std::string> table_rows_json = {
      R"([
[10, 1, 80],
[20, 2, 70],
[30, 3, 30]
])"};
  auto source_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  auto renamed_schema =
      schema({field("X", int32()), field("Y", int32()), field("Z", int32())});
  auto source_table = TableFromJSON(source_schema, table_rows_json);
  auto expected_table = TableFromJSON(renamed_schema, table_rows_json);

  ConversionOptions conversion_options;
  conversion_options.named_table_provider =
      AlwaysProvideSameTable(std::move(source_table));

  const std::string plan_json = R"({
"relations": [{
"root": {
"input": {
"project": {
"common": {
"emit": {
"outputMapping": [3, 4, 5]
}
},
"input": {
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": ["A", "B", "C"],
"struct": {
"types": [{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
}, {
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
}, {
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
}],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["SIMPLEDATA"]
}
}
},
"expressions": [{
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {
}
}
}, {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {
}
}
}, {
"selection": {
"directReference": {
"structField": {
"field": 2
}
},
"rootReference": {
}
}
}]
}
},
"names": ["X", "Y", "Z"]
}
}]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf, internal::SubstraitFromJSON("Plan", plan_json));

  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  CheckRoundTripResult(std::move(expected_table), plan_buf,
                       /*include_columns=*/{}, conversion_options);
}
// A SetRel with "SET_OP_UNION_ALL" over the named tables "T1" and "T2". The
// expected table holds the rows of both inputs, listed in ascending order of
// column A, and sort options on "A" are handed to CheckRoundTripResult.
TEST(Substrait, SetRelationBasic) {
  // The two union inputs share one schema.
  auto shared_schema =
      schema({field("A", int32()), field("B", int32()), field("C", int32())});
  auto table1 = TableFromJSON(shared_schema, {R"([
[10, 1, 80],
[20, 2, 70],
[30, 3, 30],
[40, 4, 20],
[50, 6, 20],
[200, 7, 30]
])"});
  auto table2 = TableFromJSON(shared_schema, {R"([
[70, 1, 82],
[80, 2, 72],
[90, 3, 32],
[100, 4, 22],
[110, 5, 42],
[111, 6, 22],
[112, 7, 32]
])"});
  ConversionOptions conversion_options;
  // Picks the table by name; when several names are given the last recognised one
  // wins, and an unrecognised name leaves the table pointer null.
  conversion_options.named_table_provider =
      [table1, table2](const std::vector<std::string>& names) {
        std::shared_ptr<Table> selected;
        for (const auto& name : names) {
          if (name == "T1") {
            selected = table1;
          } else if (name == "T2") {
            selected = table2;
          }
        }
        std::shared_ptr<compute::ExecNodeOptions> options =
            std::make_shared<compute::TableSourceNodeOptions>(std::move(selected));
        return compute::Declaration("table_source", {}, options, "mock_source");
      };
  std::string plan_json = R"({
"relations": [{
"root": {
"input": {
"set": {
"inputs": [{
"read": {
"baseSchema": {
"names": ["FOO"],
"struct": {
"types": [{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
}],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["T1"]
}
}
}, {
"read": {
"baseSchema": {
"names": ["BAR"],
"struct": {
"types": [{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
}],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["T2"]
}
}
}],
"op": "SET_OP_UNION_ALL"
}
},
"names": ["FOO"]
}
}]
})";
  ASSERT_OK_AND_ASSIGN(auto plan_buf, internal::SubstraitFromJSON("Plan", plan_json));
  auto expected_table = TableFromJSON(shared_schema, {R"([
[10, 1, 80],
[20, 2, 70],
[30, 3, 30],
[40, 4, 20],
[50, 6, 20],
[70, 1, 82],
[80, 2, 72],
[90, 3, 32],
[100, 4, 22],
[110, 5, 42],
[111, 6, 22],
[112, 7, 32],
[200, 7, 30]
])"});
  compute::SortOptions order_by_a(
      {compute::SortKey("A", compute::SortOrder::Ascending)});
  CheckRoundTripResult(std::move(expected_table), plan_buf, {}, conversion_options,
                       &order_by_a);
}
// Demonstrates an extension relation: an "extension_multi" rel whose detail is
// an arrow.substrait_ext.AsOfJoinRel message.
//
// Both inputs use field 0 ("time") as the "on" key and field 1 ("key") as the
// "by" key, with a tolerance of 1000. The expected schema is derived with
// compute::asofjoin::MakeOutputSchema from the two input schemas.
TEST(Substrait, PlanWithAsOfJoinExtension) {
  std::string plan_json = R"({
"extensionUris": [],
"extensions": [],
"relations": [{
"root": {
"input": {
"extension_multi": {
"common": {
"emit": {
"outputMapping": [0, 1, 2, 3]
}
},
"inputs": [
{
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": ["time", "key", "value1"],
"struct": {
"types": [
{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"fp64": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["T1"]
}
}
},
{
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": ["time", "key", "value2"],
"struct": {
"types": [
{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"fp64": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["T2"]
}
}
}
],
"detail": {
"@type": "/arrow.substrait_ext.AsOfJoinRel",
"keys" : [
{
"on": {
"selection": {
"directReference": {
"structField": {
"field": 0,
}
},
"rootReference": {}
}
},
"by": [
{
"selection": {
"directReference": {
"structField": {
"field": 1,
}
},
"rootReference": {}
}
}
]
},
{
"on": {
"selection": {
"directReference": {
"structField": {
"field": 0,
}
},
"rootReference": {}
}
},
"by": [
{
"selection": {
"directReference": {
"structField": {
"field": 1,
}
},
"rootReference": {}
}
}
]
}
],
"tolerance": 1000
}
}
},
"names": ["time", "key", "value1", "value2"]
}
}],
"expectedTypeUrls": []
})";
  // Schemas of the two join inputs, T1 first and T2 second.
  std::vector<std::shared_ptr<Schema>> input_schemas = {
      schema({field("time", int32()), field("key", int32()), field("value1", float64())}),
      schema(
          {field("time", int32()), field("key", int32()), field("value2", float64())})};
  ConversionOptions conversion_options;
  // Exactly one table name is accepted per lookup; anything but "T1" or "T2" is
  // reported as an error. The schemas are captured by reference and are declared
  // before the options that hold this provider.
  conversion_options.named_table_provider = ProvideMadeTable(
      [&input_schemas](
          const std::vector<std::string>& names) -> Result<std::shared_ptr<Table>> {
        if (names.size() != 1) {
          return Status::Invalid("Multiple test table names");
        }
        const std::string& requested = names[0];
        if (requested == "T1") {
          return TableFromJSON(input_schemas[0],
                               {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"});
        }
        if (requested == "T2") {
          return TableFromJSON(input_schemas[1],
                               {"[[1, 1, 1.2], [3, 2, 2.2], [5, 2, 3.2]]"});
        }
        return Status::Invalid("Unknown test table name ", requested);
      });
  ASSERT_OK_AND_ASSIGN(auto plan_buf, internal::SubstraitFromJSON("Plan", plan_json));
  ValidateNumProjectNodes(1, plan_buf, conversion_options);
  // The key description mirrors the "keys" entry of the plan: "on" is field 0 and
  // "by" is field 1 for both inputs.
  ASSERT_OK_AND_ASSIGN(
      auto joined_schema,
      compute::asofjoin::MakeOutputSchema(
          input_schemas, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}}));
  auto expected_table = TableFromJSON(
      joined_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"});
  CheckRoundTripResult(std::move(expected_table), plan_buf, {}, conversion_options);
}
} // namespace engine
} // namespace arrow