| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "arrow/dataset/scanner.h" |
| |
| #include <memory> |
| |
| #include <gmock/gmock.h> |
| |
| #include "arrow/compute/api_scalar.h" |
| #include "arrow/compute/api_vector.h" |
| #include "arrow/compute/cast.h" |
| #include "arrow/dataset/scanner_internal.h" |
| #include "arrow/dataset/test_util.h" |
| #include "arrow/record_batch.h" |
| #include "arrow/table.h" |
| #include "arrow/testing/generator.h" |
| #include "arrow/testing/gtest_util.h" |
| #include "arrow/testing/util.h" |
| #include "arrow/util/range.h" |
| |
| using testing::ElementsAre; |
| using testing::IsEmpty; |
| |
| namespace arrow { |
| namespace dataset { |
| |
| constexpr int64_t kNumberChildDatasets = 2; |
| constexpr int64_t kNumberBatches = 16; |
| constexpr int64_t kBatchSize = 1024; |
| |
| class TestScanner : public DatasetFixtureMixin, |
| public ::testing::WithParamInterface<bool> { |
| protected: |
| bool UseThreads() { return GetParam(); } |
| |
| std::shared_ptr<Scanner> MakeScanner(std::shared_ptr<RecordBatch> batch) { |
| std::vector<std::shared_ptr<RecordBatch>> batches{static_cast<size_t>(kNumberBatches), |
| batch}; |
| |
| DatasetVector children{static_cast<size_t>(kNumberChildDatasets), |
| std::make_shared<InMemoryDataset>(batch->schema(), batches)}; |
| |
| EXPECT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(batch->schema(), children)); |
| |
| ScannerBuilder builder(dataset, options_); |
| ARROW_EXPECT_OK(builder.UseThreads(UseThreads())); |
| EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish()); |
| return scanner; |
| } |
| |
| void AssertScannerEqualsRepetitionsOf( |
| std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch, |
| const int64_t total_batches = kNumberChildDatasets * kNumberBatches) { |
| auto expected = ConstantArrayGenerator::Repeat(total_batches, batch); |
| |
| // Verifies that the unified BatchReader is equivalent to flattening all the |
| // structures of the scanner, i.e. Scanner[Dataset[ScanTask[RecordBatch]]] |
| AssertScannerEquals(expected.get(), scanner.get()); |
| } |
| |
| void AssertScanBatchesEqualRepetitionsOf( |
| std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch, |
| const int64_t total_batches = kNumberChildDatasets * kNumberBatches) { |
| auto expected = ConstantArrayGenerator::Repeat(total_batches, batch); |
| |
| AssertScanBatchesEquals(expected.get(), scanner.get()); |
| } |
| |
| void AssertScanBatchesUnorderedEqualRepetitionsOf( |
| std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch, |
| const int64_t total_batches = kNumberChildDatasets * kNumberBatches) { |
| auto expected = ConstantArrayGenerator::Repeat(total_batches, batch); |
| |
| AssertScanBatchesUnorderedEquals(expected.get(), scanner.get()); |
| } |
| }; |
| |
| TEST_P(TestScanner, Scan) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| AssertScannerEqualsRepetitionsOf(MakeScanner(batch), batch); |
| } |
| |
| TEST_P(TestScanner, ScanBatches) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| AssertScanBatchesEqualRepetitionsOf(MakeScanner(batch), batch); |
| } |
| |
| TEST_P(TestScanner, ScanBatchesUnordered) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| AssertScanBatchesUnorderedEqualRepetitionsOf(MakeScanner(batch), batch); |
| } |
| |
| TEST_P(TestScanner, ScanWithCappedBatchSize) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| options_->batch_size = kBatchSize / 2; |
| auto expected = batch->Slice(kBatchSize / 2); |
| AssertScannerEqualsRepetitionsOf(MakeScanner(batch), expected, |
| kNumberChildDatasets * kNumberBatches * 2); |
| } |
| |
| TEST_P(TestScanner, FilteredScan) { |
| SetSchema({field("f64", float64())}); |
| |
| double value = 0.5; |
| ASSERT_OK_AND_ASSIGN(auto f64, |
| ArrayFromBuilderVisitor(float64(), kBatchSize, kBatchSize / 2, |
| [&](DoubleBuilder* builder) { |
| builder->UnsafeAppend(value); |
| builder->UnsafeAppend(-value); |
| value += 1.0; |
| })); |
| |
| SetFilter(greater(field_ref("f64"), literal(0.0))); |
| |
| auto batch = RecordBatch::Make(schema_, f64->length(), {f64}); |
| |
| value = 0.5; |
| ASSERT_OK_AND_ASSIGN( |
| auto f64_filtered, |
| ArrayFromBuilderVisitor(float64(), kBatchSize / 2, [&](DoubleBuilder* builder) { |
| builder->UnsafeAppend(value); |
| value += 1.0; |
| })); |
| |
| auto filtered_batch = |
| RecordBatch::Make(schema_, f64_filtered->length(), {f64_filtered}); |
| |
| AssertScannerEqualsRepetitionsOf(MakeScanner(batch), filtered_batch); |
| } |
| |
| TEST_P(TestScanner, MaterializeMissingColumn) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch_missing_f64 = |
| ConstantArrayGenerator::Zeroes(kBatchSize, schema({field("i32", int32())})); |
| |
| auto fragment_missing_f64 = std::make_shared<InMemoryFragment>( |
| RecordBatchVector{static_cast<size_t>(kNumberChildDatasets * kNumberBatches), |
| batch_missing_f64}, |
| equal(field_ref("f64"), literal(2.5))); |
| |
| ASSERT_OK_AND_ASSIGN(auto f64, ArrayFromBuilderVisitor(float64(), kBatchSize, |
| [&](DoubleBuilder* builder) { |
| builder->UnsafeAppend(2.5); |
| })); |
| auto batch_with_f64 = |
| RecordBatch::Make(schema_, f64->length(), {batch_missing_f64->column(0), f64}); |
| |
| ScannerBuilder builder{schema_, fragment_missing_f64, options_}; |
| ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); |
| |
| AssertScannerEqualsRepetitionsOf(scanner, batch_with_f64); |
| } |
| |
| TEST_P(TestScanner, ToTable) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| std::vector<std::shared_ptr<RecordBatch>> batches{kNumberBatches * kNumberChildDatasets, |
| batch}; |
| |
| ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(batches)); |
| |
| auto scanner = MakeScanner(batch); |
| std::shared_ptr<Table> actual; |
| |
| // There is no guarantee on the ordering when using multiple threads, but |
| // since the RecordBatch is always the same it will pass. |
| ASSERT_OK_AND_ASSIGN(actual, scanner->ToTable()); |
| AssertTablesEqual(*expected, *actual); |
| } |
| |
| TEST_P(TestScanner, ScanWithVisitor) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| auto scanner = MakeScanner(batch); |
| ASSERT_OK(scanner->Scan([batch](TaggedRecordBatch scanned_batch) { |
| AssertBatchesEqual(*batch, *scanned_batch.record_batch); |
| return Status::OK(); |
| })); |
| } |
| |
| TEST_P(TestScanner, TakeIndices) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| ArrayVector arrays(2); |
| ArrayFromVector<Int32Type>(internal::Iota<int32_t>(kBatchSize), &arrays[0]); |
| ArrayFromVector<DoubleType>(internal::Iota<double>(static_cast<double>(kBatchSize)), |
| &arrays[1]); |
| auto batch = RecordBatch::Make(schema_, kBatchSize, arrays); |
| |
| auto scanner = MakeScanner(batch); |
| |
| std::shared_ptr<Array> indices; |
| { |
| ArrayFromVector<Int64Type>(internal::Iota(kBatchSize), &indices); |
| ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices)); |
| ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches({batch})); |
| ASSERT_EQ(expected->num_rows(), kBatchSize); |
| AssertTablesEqual(*expected, *taken); |
| } |
| { |
| ArrayFromVector<Int64Type>({7, 5, 3, 1}, &indices); |
| ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices)); |
| ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable()); |
| ASSERT_OK_AND_ASSIGN(auto expected, compute::Take(table, *indices)); |
| ASSERT_EQ(expected.table()->num_rows(), 4); |
| AssertTablesEqual(*expected.table(), *taken); |
| } |
| { |
| ArrayFromVector<Int64Type>({kBatchSize + 2, kBatchSize + 1}, &indices); |
| ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable()); |
| ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices)); |
| ASSERT_OK_AND_ASSIGN(auto expected, compute::Take(table, *indices)); |
| ASSERT_EQ(expected.table()->num_rows(), 2); |
| AssertTablesEqual(*expected.table(), *taken); |
| } |
| { |
| ArrayFromVector<Int64Type>({1, 3, 5, 7, kBatchSize + 1, 2 * kBatchSize + 2}, |
| &indices); |
| ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices)); |
| ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable()); |
| ASSERT_OK_AND_ASSIGN(auto expected, compute::Take(table, *indices)); |
| ASSERT_EQ(expected.table()->num_rows(), 6); |
| AssertTablesEqual(*expected.table(), *taken); |
| } |
| { |
| auto base = kNumberChildDatasets * kNumberBatches * kBatchSize; |
| ArrayFromVector<Int64Type>({base + 1}, &indices); |
| EXPECT_RAISES_WITH_MESSAGE_THAT( |
| IndexError, ::testing::HasSubstr("Some indices were out of bounds: 32769"), |
| scanner->TakeRows(*indices)); |
| } |
| { |
| auto base = kNumberChildDatasets * kNumberBatches * kBatchSize; |
| ArrayFromVector<Int64Type>( |
| {1, 2, base + 1, base + 2, base + 3, base + 4, base + 5, base + 6}, &indices); |
| EXPECT_RAISES_WITH_MESSAGE_THAT( |
| IndexError, |
| ::testing::HasSubstr("Some indices were out of bounds: 32769, 32770, 32771, ..."), |
| scanner->TakeRows(*indices)); |
| } |
| } |
| |
| class FailingFragment : public InMemoryFragment { |
| public: |
| using InMemoryFragment::InMemoryFragment; |
| Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override { |
| int index = 0; |
| auto self = shared_from_this(); |
| return MakeFunctionIterator([=]() mutable -> Result<std::shared_ptr<ScanTask>> { |
| if (index > 16) { |
| return Status::Invalid("Oh no, we failed!"); |
| } |
| RecordBatchVector batches = {record_batches_[index++ % record_batches_.size()]}; |
| return std::make_shared<InMemoryScanTask>(batches, options, self); |
| }); |
| } |
| }; |
| |
| TEST_P(TestScanner, ScanBatchesFailure) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| RecordBatchVector batches = {batch, batch, batch, batch}; |
| |
| ScannerBuilder builder(schema_, std::make_shared<FailingFragment>(batches), options_); |
| ASSERT_OK(builder.UseThreads(UseThreads())); |
| ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); |
| |
| ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches()); |
| |
| int counter = 0; |
| while (true) { |
| // Make sure we get all batches that were yielded before the failing scan task |
| auto maybe_batch = batch_it.Next(); |
| if (counter++ <= 16) { |
| ASSERT_OK_AND_ASSIGN(auto scanned_batch, maybe_batch); |
| AssertBatchesEqual(*batch, *scanned_batch.record_batch); |
| ASSERT_NE(nullptr, scanned_batch.fragment); |
| } else { |
| EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"), |
| maybe_batch); |
| break; |
| } |
| } |
| } |
| |
| TEST_P(TestScanner, Head) { |
| SetSchema({field("i32", int32()), field("f64", float64())}); |
| auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); |
| |
| auto scanner = MakeScanner(batch); |
| std::shared_ptr<Table> expected, actual; |
| |
| ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {})); |
| ASSERT_OK_AND_ASSIGN(actual, scanner->Head(0)); |
| AssertTablesEqual(*expected, *actual); |
| |
| ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {batch})); |
| ASSERT_OK_AND_ASSIGN(actual, scanner->Head(kBatchSize)); |
| AssertTablesEqual(*expected, *actual); |
| |
| ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {batch->Slice(0, 1)})); |
| ASSERT_OK_AND_ASSIGN(actual, scanner->Head(1)); |
| AssertTablesEqual(*expected, *actual); |
| |
| ASSERT_OK_AND_ASSIGN(expected, |
| Table::FromRecordBatches(schema_, {batch, batch->Slice(0, 1)})); |
| ASSERT_OK_AND_ASSIGN(actual, scanner->Head(kBatchSize + 1)); |
| AssertTablesEqual(*expected, *actual); |
| |
| ASSERT_OK_AND_ASSIGN(expected, scanner->ToTable()); |
| ASSERT_OK_AND_ASSIGN(actual, |
| scanner->Head(kBatchSize * kNumberBatches * kNumberChildDatasets)); |
| AssertTablesEqual(*expected, *actual); |
| |
| ASSERT_OK_AND_ASSIGN(expected, scanner->ToTable()); |
| ASSERT_OK_AND_ASSIGN( |
| actual, scanner->Head(kBatchSize * kNumberBatches * kNumberChildDatasets + 100)); |
| AssertTablesEqual(*expected, *actual); |
| } |
| |
| INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner, ::testing::Bool()); |
| |
| class TestScannerBuilder : public ::testing::Test { |
| void SetUp() override { |
| DatasetVector sources; |
| |
| schema_ = schema({ |
| field("b", boolean()), |
| field("i8", int8()), |
| field("i16", int16()), |
| field("i32", int32()), |
| field("i64", int64()), |
| }); |
| |
| ASSERT_OK_AND_ASSIGN(dataset_, UnionDataset::Make(schema_, sources)); |
| } |
| |
| protected: |
| std::shared_ptr<ScanOptions> options_ = std::make_shared<ScanOptions>(); |
| std::shared_ptr<Schema> schema_; |
| std::shared_ptr<Dataset> dataset_; |
| }; |
| |
| TEST_F(TestScannerBuilder, TestProject) { |
| ScannerBuilder builder(dataset_, options_); |
| |
| // It is valid to request no columns, e.g. `SELECT 1 FROM t WHERE t.a > 0`. |
| // still needs to touch the `a` column. |
| ASSERT_OK(builder.Project({})); |
| ASSERT_OK(builder.Project({"i64", "b", "i8"})); |
| ASSERT_OK(builder.Project({"i16", "i16"})); |
| ASSERT_OK(builder.Project( |
| {field_ref("i16"), call("multiply", {field_ref("i16"), literal(2)})}, |
| {"i16 renamed", "i16 * 2"})); |
| |
| ASSERT_RAISES(Invalid, builder.Project({"not_found_column"})); |
| ASSERT_RAISES(Invalid, builder.Project({"i8", "not_found_column"})); |
| ASSERT_RAISES(Invalid, |
| builder.Project({field_ref("not_found_column"), |
| call("multiply", {field_ref("i16"), literal(2)})}, |
| {"i16 renamed", "i16 * 2"})); |
| |
| ASSERT_RAISES(NotImplemented, builder.Project({field_ref(FieldRef("nested", "column"))}, |
| {"nested column"})); |
| |
| // provided more field names than column exprs or vice versa |
| ASSERT_RAISES(Invalid, builder.Project({}, {"i16 renamed", "i16 * 2"})); |
| ASSERT_RAISES(Invalid, builder.Project({literal(2), field_ref("a")}, {"a"})); |
| } |
| |
| TEST_F(TestScannerBuilder, TestFilter) { |
| ScannerBuilder builder(dataset_, options_); |
| |
| ASSERT_OK(builder.Filter(literal(true))); |
| ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal<int64_t>(10)))); |
| ASSERT_OK(builder.Filter(or_(equal(field_ref("i64"), literal<int64_t>(10)), |
| equal(field_ref("b"), literal(true))))); |
| |
| ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal<double>(10)))); |
| |
| ASSERT_RAISES(Invalid, builder.Filter(equal(field_ref("not_a_column"), literal(true)))); |
| |
| ASSERT_RAISES( |
| NotImplemented, |
| builder.Filter(equal(field_ref(FieldRef("nested", "column")), literal(true)))); |
| |
| ASSERT_RAISES(Invalid, |
| builder.Filter(or_(equal(field_ref("i64"), literal<int64_t>(10)), |
| equal(field_ref("not_a_column"), literal(true))))); |
| } |
| |
| TEST(ScanOptions, TestMaterializedFields) { |
| auto i32 = field("i32", int32()); |
| auto i64 = field("i64", int64()); |
| auto opts = std::make_shared<ScanOptions>(); |
| |
| // empty dataset, project nothing = nothing materialized |
| opts->dataset_schema = schema({}); |
| ASSERT_OK(SetProjection(opts.get(), {}, {})); |
| EXPECT_THAT(opts->MaterializedFields(), IsEmpty()); |
| |
| // non-empty dataset, project nothing = nothing materialized |
| opts->dataset_schema = schema({i32, i64}); |
| EXPECT_THAT(opts->MaterializedFields(), IsEmpty()); |
| |
| // project nothing, filter on i32 = materialize i32 |
| opts->filter = equal(field_ref("i32"), literal(10)); |
| EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); |
| |
| // project i32 & i64, filter nothing = materialize i32 & i64 |
| opts->filter = literal(true); |
| ASSERT_OK(SetProjection(opts.get(), {"i32", "i64"})); |
| EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64")); |
| |
| // project i32 + i64, filter nothing = materialize i32 & i64 |
| opts->filter = literal(true); |
| ASSERT_OK(SetProjection(opts.get(), {call("add", {field_ref("i32"), field_ref("i64")})}, |
| {"i32 + i64"})); |
| EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64")); |
| |
| // project i32, filter nothing = materialize i32 |
| ASSERT_OK(SetProjection(opts.get(), {"i32"})); |
| EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); |
| |
| // project i32, filter on i32 = materialize i32 (reported twice) |
| opts->filter = equal(field_ref("i32"), literal(10)); |
| EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32")); |
| |
| // project i32, filter on i32 & i64 = materialize i64, i32 (reported twice) |
| opts->filter = less(field_ref("i32"), field_ref("i64")); |
| EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64", "i32")); |
| |
| // project i32, filter on i64 = materialize i32 & i64 |
| opts->filter = equal(field_ref("i64"), literal(10)); |
| EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i64", "i32")); |
| } |
| |
| } // namespace dataset |
| } // namespace arrow |