// blob: d1d0e45b8275d590001bca2e92781793881c2164 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/dataset/scanner.h"
#include <memory>
#include "arrow/dataset/scanner_internal.h"
#include "arrow/dataset/test_util.h"
#include "arrow/record_batch.h"
#include "arrow/table.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/util.h"
using testing::ElementsAre;
using testing::IsEmpty;
namespace arrow {
namespace dataset {
// Number of child datasets TestScanner::MakeScanner() places under the
// UnionDataset it scans.
constexpr int64_t kNumberChildDatasets = 2;
// Number of copies of the input batch held by each child InMemoryDataset.
constexpr int64_t kNumberBatches = 16;
// Size argument handed to the array generators when the tests build a batch
// (presumably the number of rows per batch).
constexpr int64_t kBatchSize = 1024;
class TestScanner : public DatasetFixtureMixin {
protected:
std::shared_ptr<Scanner> MakeScanner(std::shared_ptr<RecordBatch> batch) {
std::vector<std::shared_ptr<RecordBatch>> batches{static_cast<size_t>(kNumberBatches),
batch};
DatasetVector children{static_cast<size_t>(kNumberChildDatasets),
std::make_shared<InMemoryDataset>(batch->schema(), batches)};
EXPECT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(batch->schema(), children));
ScannerBuilder builder(dataset, options_);
EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
return scanner;
}
void AssertScannerEqualsRepetitionsOf(
std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
// Verifies that the unified BatchReader is equivalent to flattening all the
// structures of the scanner, i.e. Scanner[Dataset[ScanTask[RecordBatch]]]
AssertScannerEquals(expected.get(), scanner.get());
}
void AssertScanBatchesEqualRepetitionsOf(
std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
AssertScanBatchesEquals(expected.get(), scanner.get());
}
void AssertScanBatchesUnorderedEqualRepetitionsOf(
std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
AssertScanBatchesUnorderedEquals(expected.get(), scanner.get());
}
};
// Scanning the union dataset yields the zero-filled batch once per stored
// copy (default: kNumberChildDatasets * kNumberBatches repetitions).
TEST_F(TestScanner, Scan) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  const auto zeroes = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
  auto scanner = MakeScanner(zeroes);
  AssertScannerEqualsRepetitionsOf(scanner, zeroes);
}
// Same data as the Scan test, checked through the ScanBatches assertion
// helper of the fixture.
TEST_F(TestScanner, ScanBatches) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  const auto zeroes = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
  auto scanner = MakeScanner(zeroes);
  AssertScanBatchesEqualRepetitionsOf(scanner, zeroes);
}
// Same data as the Scan test, checked through the unordered ScanBatches
// assertion helper of the fixture.
TEST_F(TestScanner, ScanBatchesUnordered) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  const auto zeroes = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
  auto scanner = MakeScanner(zeroes);
  AssertScanBatchesUnorderedEqualRepetitionsOf(scanner, zeroes);
}
// With options_->batch_size set to half of kBatchSize, the scan is expected
// to yield twice as many batches, each equal to a half-size slice of the
// (all-zero) input batch.
TEST_F(TestScanner, ScanWithCappedBatchSize) {
  constexpr int64_t kCappedSize = kBatchSize / 2;
  constexpr int64_t kExpectedBatches = kNumberChildDatasets * kNumberBatches * 2;

  SetSchema({field("i32", int32()), field("f64", float64())});
  const auto full_batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);

  options_->batch_size = kCappedSize;

  // Slice(offset) keeps everything from `offset` on, i.e. the trailing half.
  const auto half_batch = full_batch->Slice(kCappedSize);
  auto scanner = MakeScanner(full_batch);
  AssertScannerEqualsRepetitionsOf(scanner, half_batch, kExpectedBatches);
}
// Scans a column of alternating positive/negative values with the filter
// `f64 > 0.0` and expects only the positive values to come back.
TEST_F(TestScanner, FilteredScan) {
  SetSchema({field("f64", float64())});

  // Input column: 0.5, -0.5, 1.5, -1.5, ... (two appends per visitor call,
  // kBatchSize / 2 calls).
  double next_input = 0.5;
  ASSERT_OK_AND_ASSIGN(auto input_column,
                       ArrayFromBuilderVisitor(float64(), kBatchSize, kBatchSize / 2,
                                               [&](DoubleBuilder* builder) {
                                                 builder->UnsafeAppend(next_input);
                                                 builder->UnsafeAppend(-next_input);
                                                 next_input += 1.0;
                                               }));

  SetFilter(greater(field_ref("f64"), literal(0.0)));

  auto input_batch = RecordBatch::Make(schema_, input_column->length(), {input_column});

  // Expected column: the positive values only, 0.5, 1.5, 2.5, ...
  double next_expected = 0.5;
  ASSERT_OK_AND_ASSIGN(
      auto expected_column,
      ArrayFromBuilderVisitor(float64(), kBatchSize / 2, [&](DoubleBuilder* builder) {
        builder->UnsafeAppend(next_expected);
        next_expected += 1.0;
      }));

  auto expected_batch =
      RecordBatch::Make(schema_, expected_column->length(), {expected_column});

  auto scanner = MakeScanner(input_batch);
  AssertScannerEqualsRepetitionsOf(scanner, expected_batch);
}
// The fragment's batches only carry "i32"; the fragment is constructed with
// the expression `f64 == 2.5`. The scan over the full (i32, f64) schema is
// expected to yield batches whose "f64" column is filled with 2.5.
TEST_F(TestScanner, MaterializeMissingColumn) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  // Batches physically stored in the fragment: no "f64" column.
  auto i32_only_batch =
      ConstantArrayGenerator::Zeroes(kBatchSize, schema({field("i32", int32())}));
  RecordBatchVector stored_batches(
      static_cast<size_t>(kNumberChildDatasets * kNumberBatches), i32_only_batch);
  auto fragment = std::make_shared<InMemoryFragment>(
      stored_batches, equal(field_ref("f64"), literal(2.5)));

  // Expected output: the stored "i32" column plus a constant 2.5 "f64" column.
  ASSERT_OK_AND_ASSIGN(auto constant_f64,
                       ArrayFromBuilderVisitor(float64(), kBatchSize,
                                               [&](DoubleBuilder* builder) {
                                                 builder->UnsafeAppend(2.5);
                                               }));
  auto expected_batch = RecordBatch::Make(schema_, constant_f64->length(),
                                          {i32_only_batch->column(0), constant_f64});

  ScannerBuilder scanner_builder{schema_, fragment, options_};
  ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder.Finish());
  AssertScannerEqualsRepetitionsOf(scanner, expected_batch);
}
// Scanner::ToTable() must produce the same table as concatenating every
// stored batch, both with and without threads.
TEST_F(TestScanner, ToTable) {
  SetSchema({field("i32", int32()), field("f64", float64())});
  auto zeroes = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);

  RecordBatchVector all_batches(
      static_cast<size_t>(kNumberBatches * kNumberChildDatasets), zeroes);
  ASSERT_OK_AND_ASSIGN(auto expected_table, Table::FromRecordBatches(all_batches));

  auto scanner = MakeScanner(zeroes);

  // Single-threaded conversion.
  options_->use_threads = false;
  ASSERT_OK_AND_ASSIGN(auto serial_table, scanner->ToTable());
  AssertTablesEqual(*expected_table, *serial_table);

  // There is no guarantee on the ordering when using multiple threads, but
  // since the RecordBatch is always the same it will pass.
  options_->use_threads = true;
  ASSERT_OK_AND_ASSIGN(auto threaded_table, scanner->ToTable());
  AssertTablesEqual(*expected_table, *threaded_table);
}
class TestScannerBuilder : public ::testing::Test {
void SetUp() override {
DatasetVector sources;
schema_ = schema({
field("b", boolean()),
field("i8", int8()),
field("i16", int16()),
field("i32", int32()),
field("i64", int64()),
});
ASSERT_OK_AND_ASSIGN(dataset_, UnionDataset::Make(schema_, sources));
}
protected:
std::shared_ptr<ScanOptions> options_ = std::make_shared<ScanOptions>();
std::shared_ptr<Schema> schema_;
std::shared_ptr<Dataset> dataset_;
};
// Exercises ScannerBuilder::Project() with valid and invalid projections.
TEST_F(TestScannerBuilder, TestProject) {
  ScannerBuilder scanner_builder(dataset_, options_);
  const auto doubled_i16 = call("multiply", {field_ref("i16"), literal(2)});

  // Requesting no columns is valid: a query such as
  // `SELECT 1 FROM t WHERE t.a > 0` projects nothing, even though it still
  // needs to touch the `a` column for the filter.
  ASSERT_OK(scanner_builder.Project({}));

  // Columns by name, in any order and with repetition.
  ASSERT_OK(scanner_builder.Project({"i64", "b", "i8"}));
  ASSERT_OK(scanner_builder.Project({"i16", "i16"}));

  // Expressions paired with output names.
  ASSERT_OK(scanner_builder.Project({field_ref("i16"), doubled_i16},
                                    {"i16 renamed", "i16 * 2"}));

  // Referencing a column absent from the dataset schema is rejected.
  ASSERT_RAISES(Invalid, scanner_builder.Project({"not_found_column"}));
  ASSERT_RAISES(Invalid, scanner_builder.Project({"i8", "not_found_column"}));
  ASSERT_RAISES(Invalid,
                scanner_builder.Project({field_ref("not_found_column"), doubled_i16},
                                        {"i16 renamed", "i16 * 2"}));

  // Nested field references are reported as NotImplemented.
  ASSERT_RAISES(NotImplemented,
                scanner_builder.Project({field_ref(FieldRef("nested", "column"))},
                                        {"nested column"}));

  // provided more field names than column exprs or vice versa
  ASSERT_RAISES(Invalid, scanner_builder.Project({}, {"i16 renamed", "i16 * 2"}));
  ASSERT_RAISES(Invalid,
                scanner_builder.Project({literal(2), field_ref("a")}, {"a"}));
}
// Exercises ScannerBuilder::Filter() with valid and invalid predicates.
TEST_F(TestScannerBuilder, TestFilter) {
  ScannerBuilder scanner_builder(dataset_, options_);

  const auto i64_is_ten = equal(field_ref("i64"), literal<int64_t>(10));
  const auto b_is_true = equal(field_ref("b"), literal(true));
  const auto unknown_column_is_true = equal(field_ref("not_a_column"), literal(true));

  // Predicates over existing columns are accepted, including a literal of a
  // different numeric type than the column.
  ASSERT_OK(scanner_builder.Filter(literal(true)));
  ASSERT_OK(scanner_builder.Filter(i64_is_ten));
  ASSERT_OK(scanner_builder.Filter(or_(i64_is_ten, b_is_true)));
  ASSERT_OK(scanner_builder.Filter(equal(field_ref("i64"), literal<double>(10))));

  // A reference to a column absent from the dataset schema is rejected.
  ASSERT_RAISES(Invalid, scanner_builder.Filter(unknown_column_is_true));

  // Nested field references are reported as NotImplemented.
  ASSERT_RAISES(NotImplemented,
                scanner_builder.Filter(
                    equal(field_ref(FieldRef("nested", "column")), literal(true))));

  // One bad operand is enough to invalidate a compound predicate.
  ASSERT_RAISES(Invalid,
                scanner_builder.Filter(or_(i64_is_ten, unknown_column_is_true)));
}
// Checks which field names ScanOptions::MaterializedFields() reports for
// different combinations of projection and filter. In the expectations below,
// fields referenced by the filter are listed before the projected fields, and
// a field used by both is listed twice.
TEST(ScanOptions, TestMaterializedFields) {
  const auto i32_field = field("i32", int32());
  const auto i64_field = field("i64", int64());

  const auto keep_everything = literal(true);
  const auto i32_equals_ten = equal(field_ref("i32"), literal(10));
  const auto i64_equals_ten = equal(field_ref("i64"), literal(10));
  const auto i32_less_than_i64 = less(field_ref("i32"), field_ref("i64"));

  const auto options = std::make_shared<ScanOptions>();

  // Empty dataset schema, empty projection: nothing is materialized.
  options->dataset_schema = schema({});
  ASSERT_OK(SetProjection(options.get(), {}, {}));
  EXPECT_THAT(options->MaterializedFields(), IsEmpty());

  // Non-empty dataset schema, still an empty projection: nothing is
  // materialized.
  options->dataset_schema = schema({i32_field, i64_field});
  EXPECT_THAT(options->MaterializedFields(), IsEmpty());

  // Empty projection, filter on i32: i32 is materialized.
  options->filter = i32_equals_ten;
  EXPECT_THAT(options->MaterializedFields(), ElementsAre("i32"));

  // Projection of i32 and i64 by name, trivial filter: both are materialized.
  options->filter = keep_everything;
  ASSERT_OK(SetProjection(options.get(), {"i32", "i64"}));
  EXPECT_THAT(options->MaterializedFields(), ElementsAre("i32", "i64"));

  // Projection of the expression i32 + i64, trivial filter: both inputs are
  // materialized.
  options->filter = keep_everything;
  ASSERT_OK(SetProjection(options.get(),
                          {call("add", {field_ref("i32"), field_ref("i64")})},
                          {"i32 + i64"}));
  EXPECT_THAT(options->MaterializedFields(), ElementsAre("i32", "i64"));

  // Projection of i32 only, trivial filter: i32 is materialized.
  ASSERT_OK(SetProjection(options.get(), {"i32"}));
  EXPECT_THAT(options->MaterializedFields(), ElementsAre("i32"));

  // Projection of i32, filter on i32: i32 is reported twice.
  options->filter = i32_equals_ten;
  EXPECT_THAT(options->MaterializedFields(), ElementsAre("i32", "i32"));

  // Projection of i32, filter on i32 and i64: i32 and i64 from the filter,
  // then i32 again from the projection.
  options->filter = i32_less_than_i64;
  EXPECT_THAT(options->MaterializedFields(), ElementsAre("i32", "i64", "i32"));

  // Projection of i32, filter on i64: i64 from the filter, then i32.
  options->filter = i64_equals_ten;
  EXPECT_THAT(options->MaterializedFields(), ElementsAre("i64", "i32"));
}
} // namespace dataset
} // namespace arrow