blob: 63087f33cfbcef6486fabbf7f5ac000b83767d7d [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "benchmark/benchmark.h"
#include "gandiva/arrow.h"
#include "gandiva/filter.h"
#include "gandiva/projector.h"
#include "gandiva/tests/generate_data.h"
#pragma once
// Size multipliers for benchmark parameters. Note that these are binary
// multiples (2^10 and 2^20), not the decimal values their names suggest.
#define THOUSAND (1024)
#define MILLION (1024 * 1024)
// Number of distinct record batches that TimedEvaluate generates up front and
// then cycles through while timing.
#define NUM_BATCHES 16
namespace gandiva {
// Draws `num_records` values from `data_generator`, one GenerateData() call per
// value, and returns them in the order they were produced.
//
// A non-positive `num_records` yields an empty vector and does not touch the
// generator.
template <typename C_TYPE>
std::vector<C_TYPE> GenerateData(int num_records, DataGenerator<C_TYPE>& data_generator) {
  std::vector<C_TYPE> data;
  if (num_records > 0) {
    // The final size is known up front, so allocate once instead of letting
    // push_back reallocate repeatedly while the vector grows.
    data.reserve(static_cast<typename std::vector<C_TYPE>::size_type>(num_records));
  }
  for (int i = 0; i < num_records; ++i) {
    data.push_back(data_generator.GenerateData());
  }
  return data;
}
// Abstract interface for the evaluators that TimedEvaluate can drive. Each
// implementation processes one record batch per Evaluate() call.
class BaseEvaluator {
 public:
  // Runs the evaluator over `batch`; `pool` is handed to the implementation for
  // whatever allocations it performs. Returns the resulting Status.
  virtual Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) = 0;

  // Virtual so that concrete evaluators are destroyed correctly through a
  // BaseEvaluator pointer or reference.
  virtual ~BaseEvaluator() = default;
};
// BaseEvaluator that runs a Projector over each batch it is given.
class ProjectEvaluator : public BaseEvaluator {
 public:
  // Shares ownership of `projector`. The by-value parameter is moved into the
  // member so that construction does not pay for an extra reference-count
  // increment and decrement.
  explicit ProjectEvaluator(std::shared_ptr<Projector> projector)
      : projector_(std::move(projector)) {}

  // Projects `batch`, allocating from `pool`, and returns the projector's
  // Status. The output arrays live only for the duration of the call and are
  // discarded on return.
  Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override {
    arrow::ArrayVector outputs;
    return projector_->Evaluate(batch, pool, &outputs);
  }

 private:
  std::shared_ptr<Projector> projector_;
};
// BaseEvaluator that runs a Filter over each batch it is given, reusing one
// selection vector across calls whenever it is large enough.
class FilterEvaluator : public BaseEvaluator {
 public:
  // Shares ownership of `filter`. The by-value parameter is moved into the
  // member so that construction does not pay for an extra reference-count
  // increment and decrement.
  explicit FilterEvaluator(std::shared_ptr<Filter> filter) : filter_(std::move(filter)) {}

  // Filters `batch` into the cached selection vector and returns the filter's
  // Status. If the selection vector has to be (re)created and that fails, the
  // creation error is returned and the filter is not run.
  Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override {
    // Allocate on first use, and again only when a batch has more rows than
    // the current selection vector can hold.
    const bool needs_new_selection =
        selection_ == nullptr || selection_->GetMaxSlots() < batch.num_rows();
    if (needs_new_selection) {
      // NOTE(review): MakeInt16 presumably yields 16-bit row indices, which
      // would limit the batch sizes this evaluator can handle -- confirm
      // against SelectionVector before benchmarking very large batches.
      auto status = SelectionVector::MakeInt16(batch.num_rows(), pool, &selection_);
      if (!status.ok()) {
        return status;
      }
    }
    return filter_->Evaluate(batch, selection_);
  }

 private:
  std::shared_ptr<Filter> filter_;
  // Lazily created by Evaluate() and kept between calls.
  std::shared_ptr<SelectionVector> selection_;
};
// Benchmarks `evaluator` over synthetic data shaped like `schema`.
//
// NUM_BATCHES record batches of `batch_size` rows are generated before timing
// starts; every column is filled from `data_generator` and marked fully valid.
// Each benchmark iteration then evaluates batches, cycling through the
// pre-generated ones, until `num_records` records have been accounted for.
//
// Parameters:
//   schema          schema of the generated batches; every field is populated
//                   with C_TYPE values.
//   evaluator       evaluator under test.
//   data_generator  source of the column values.
//   pool            memory pool forwarded to the evaluator.
//   num_records     records to account for per benchmark iteration.
//   batch_size      rows per generated batch; must be positive.
//   state           benchmark state that drives the loop and receives the
//                   byte/item counters.
//
// Returns Status::OK() on success. If `batch_size` is not positive, or an
// evaluation fails, the benchmark is marked as skipped with an error and the
// failing Status is returned without setting the counters.
template <typename TYPE, typename C_TYPE>
Status TimedEvaluate(SchemaPtr schema, BaseEvaluator& evaluator,
                     DataGenerator<C_TYPE>& data_generator, arrow::MemoryPool* pool,
                     int num_records, int batch_size, benchmark::State& state) {
  // A non-positive batch size would never reduce the remaining-record count,
  // so the timed loop below could not terminate.
  if (batch_size <= 0) {
    state.SkipWithError("batch_size must be positive");
    return Status::Invalid("batch_size must be positive");
  }

  const int num_fields = schema->num_fields();
  // Every column holds C_TYPE values, so one record spans this many data bytes.
  const int64_t bytes_per_record =
      static_cast<int64_t>(num_fields) * static_cast<int64_t>(sizeof(C_TYPE));

  // Generate batches of data
  std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
  batches.reserve(NUM_BATCHES);
  for (int i = 0; i < NUM_BATCHES; ++i) {
    // generate data for all columns in the schema
    std::vector<ArrayPtr> columns;
    for (int col = 0; col < num_fields; ++col) {
      std::vector<C_TYPE> data = GenerateData<C_TYPE>(batch_size, data_generator);
      std::vector<bool> validity(batch_size, true);
      columns.push_back(
          MakeArrowArray<TYPE, C_TYPE>(schema->field(col)->type(), data, validity));
    }
    // make the record batch
    batches.push_back(arrow::RecordBatch::Make(schema, batch_size, columns));
  }

  int64_t total_bytes_processed = 0;
  int64_t total_items_processed = 0;
  // Index of the batch to evaluate next. It is wrapped on every step rather
  // than derived from an ever-growing call counter, so it cannot overflow.
  int next_batch = 0;

  for (auto _ : state) {
    int num_remaining = num_records;
    while (num_remaining > 0) {
      // The last step of an iteration may account for fewer records than a
      // full batch. Note that the evaluator is still handed the whole batch.
      const int num_in_batch = batch_size < num_remaining ? batch_size : num_remaining;

      Status status = evaluator.Evaluate(*batches[next_batch], pool);
      if (!status.ok()) {
        state.SkipWithError("Evaluation of the batch failed");
        return status;
      }

      next_batch = (next_batch + 1) % NUM_BATCHES;
      num_remaining -= num_in_batch;
      // Both counters are accumulated per evaluated call so that they scale
      // with the number of benchmark iterations, which is what the throughput
      // figures derived from them require.
      total_items_processed += num_in_batch;
      total_bytes_processed += num_in_batch * bytes_per_record;
    }
  }

  state.SetBytesProcessed(total_bytes_processed);
  state.SetItemsProcessed(total_items_processed);
  return Status::OK();
}
} // namespace gandiva