blob: 0ffe53578c4b693bf2913fc5240827dd8f67c36a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "./arrow_types.h"
#if defined(ARROW_R_WITH_ARROW)
#include <arrow/compute/api.h>
#include <arrow/record_batch.h>
#include <arrow/table.h>
std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options);
// Returns a pointer to a single ExecContext shared by all callers.
// The context is a function-local static, constructed from gc_memory_pool()
// the first time this function runs.
arrow::compute::ExecContext* gc_context() {
  static arrow::compute::ExecContext shared_context(gc_memory_pool());
  return &shared_context;
}
// Casts every column of `batch` to the type of the schema field at the same
// position and returns a new RecordBatch that carries `schema`.
//
// `options` is turned into cast options by make_cast_options(). The schema
// must describe exactly one field per column; otherwise an R error is raised
// before any cast is attempted. Cast results are unwrapped with ValueOrStop().
// [[arrow::export]]
std::shared_ptr<arrow::RecordBatch> RecordBatch__cast(
    const std::shared_ptr<arrow::RecordBatch>& batch,
    const std::shared_ptr<arrow::Schema>& schema, cpp11::list options) {
  auto opts = make_cast_options(options);
  auto nc = batch->num_columns();

  // schema->field(i) is indexed by column position below, so a schema with a
  // different number of fields would either be read out of range or produce a
  // batch whose schema does not match its columns.
  if (schema->num_fields() != nc) {
    cpp11::stop("RecordBatch__cast: schema has %d fields but batch has %d columns",
                schema->num_fields(), nc);
  }

  arrow::ArrayVector columns(nc);
  for (int i = 0; i < nc; i++) {
    columns[i] = ValueOrStop(
        arrow::compute::Cast(*batch->column(i), schema->field(i)->type(), *opts));
  }
  return arrow::RecordBatch::Make(schema, batch->num_rows(), std::move(columns));
}
// Casts every column of `table` to the type of the schema field at the same
// position and returns a new Table that carries `schema`.
//
// `options` is turned into cast options by make_cast_options(). The schema
// must describe exactly one field per column; otherwise an R error is raised
// before any cast is attempted. Cast results are unwrapped with ValueOrStop().
// [[arrow::export]]
std::shared_ptr<arrow::Table> Table__cast(const std::shared_ptr<arrow::Table>& table,
                                          const std::shared_ptr<arrow::Schema>& schema,
                                          cpp11::list options) {
  auto opts = make_cast_options(options);
  auto nc = table->num_columns();

  // schema->field(i) is indexed by column position below, so a schema with a
  // different number of fields would either be read out of range or produce a
  // table whose schema does not match its columns.
  if (schema->num_fields() != nc) {
    cpp11::stop("Table__cast: schema has %d fields but table has %d columns",
                schema->num_fields(), nc);
  }

  using ColumnVector = std::vector<std::shared_ptr<arrow::ChunkedArray>>;
  ColumnVector columns(nc);
  for (int i = 0; i < nc; i++) {
    // Each column is a ChunkedArray; wrap it in a Datum for Cast and take the
    // chunked array back out of the result.
    arrow::Datum value(table->column(i));
    arrow::Datum out =
        ValueOrStop(arrow::compute::Cast(value, schema->field(i)->type(), *opts));
    columns[i] = out.chunked_array();
  }
  return arrow::Table::Make(schema, std::move(columns), table->num_rows());
}
// Extracts a std::shared_ptr<T> from `x` when `x` inherits from both
// "ArrowObject" and `class_name`; returns a null pointer otherwise.
template <typename T>
std::shared_ptr<T> MaybeUnbox(const char* class_name, SEXP x) {
  // The second inheritance test only runs when the first one succeeds.
  const bool has_expected_class =
      Rf_inherits(x, "ArrowObject") && Rf_inherits(x, class_name);
  if (!has_expected_class) {
    return nullptr;
  }
  return cpp11::as_cpp<std::shared_ptr<T>>(x);
}
namespace cpp11 {

// Converts an R object into an arrow::Datum by trying each supported wrapper
// class in turn: Array, ChunkedArray, RecordBatch, Table, Scalar. The first
// class that MaybeUnbox() recognises wins; anything else is an R error.
template <>
arrow::Datum as_cpp<arrow::Datum>(SEXP x) {
  if (auto as_array = MaybeUnbox<arrow::Array>("Array", x)) {
    return arrow::Datum(as_array);
  }

  if (auto as_chunked = MaybeUnbox<arrow::ChunkedArray>("ChunkedArray", x)) {
    return arrow::Datum(as_chunked);
  }

  if (auto as_batch = MaybeUnbox<arrow::RecordBatch>("RecordBatch", x)) {
    return arrow::Datum(as_batch);
  }

  if (auto as_table = MaybeUnbox<arrow::Table>("Table", x)) {
    return arrow::Datum(as_table);
  }

  if (auto as_scalar = MaybeUnbox<arrow::Scalar>("Scalar", x)) {
    return arrow::Datum(as_scalar);
  }

  // Plain R values are not wrapped here: the caller is assumed to have
  // converted them to Arrow objects already. Open question kept from the
  // original author: should that wrapping happen here too/instead?
  cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x)));
}

}  // namespace cpp11
// Converts `datum` to an R object via cpp11::to_r6(), picking the accessor
// that matches the datum's kind (scalar, array, chunked array, record batch or
// table). Any other kind raises an R error that includes datum.ToString().
SEXP from_datum(arrow::Datum datum) {
  const auto kind = datum.kind();

  if (kind == arrow::Datum::SCALAR) {
    return cpp11::to_r6(datum.scalar());
  }
  if (kind == arrow::Datum::ARRAY) {
    return cpp11::to_r6(datum.make_array());
  }
  if (kind == arrow::Datum::CHUNKED_ARRAY) {
    return cpp11::to_r6(datum.chunked_array());
  }
  if (kind == arrow::Datum::RECORD_BATCH) {
    return cpp11::to_r6(datum.record_batch());
  }
  if (kind == arrow::Datum::TABLE) {
    return cpp11::to_r6(datum.table());
  }

  cpp11::stop("from_datum: Not implemented for Datum %s", datum.ToString().c_str());
}
std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
std::string func_name, cpp11::list options) {
if (func_name == "filter") {
using Options = arrow::compute::FilterOptions;
auto out = std::make_shared<Options>(Options::Defaults());
SEXP keep_na = options["keep_na"];
if (!Rf_isNull(keep_na) && cpp11::as_cpp<bool>(keep_na)) {
out->null_selection_behavior = Options::EMIT_NULL;
}
return out;
}
if (func_name == "take") {
using Options = arrow::compute::TakeOptions;
auto out = std::make_shared<Options>(Options::Defaults());
return out;
}
if (func_name == "array_sort_indices") {
using Order = arrow::compute::SortOrder;
using Options = arrow::compute::ArraySortOptions;
// false means descending, true means ascending
auto order = cpp11::as_cpp<bool>(options["order"]);
auto out =
std::make_shared<Options>(Options(order ? Order::Descending : Order::Ascending));
return out;
}
if (func_name == "sort_indices") {
using Key = arrow::compute::SortKey;
using Order = arrow::compute::SortOrder;
using Options = arrow::compute::SortOptions;
auto names = cpp11::as_cpp<std::vector<std::string>>(options["names"]);
// false means descending, true means ascending
// cpp11 does not support bool here so use int
auto orders = cpp11::as_cpp<std::vector<int>>(options["orders"]);
std::vector<Key> keys;
for (size_t i = 0; i < names.size(); i++) {
keys.push_back(
Key(names[i], (orders[i] > 0) ? Order::Descending : Order::Ascending));
}
auto out = std::make_shared<Options>(Options(keys));
return out;
}
if (func_name == "min_max") {
using Options = arrow::compute::MinMaxOptions;
auto out = std::make_shared<Options>(Options::Defaults());
out->null_handling =
cpp11::as_cpp<bool>(options["na.rm"]) ? Options::SKIP : Options::EMIT_NULL;
return out;
}
if (func_name == "quantile") {
using Options = arrow::compute::QuantileOptions;
auto out = std::make_shared<Options>(Options::Defaults());
SEXP q = options["q"];
if (!Rf_isNull(q) && TYPEOF(q) == REALSXP) {
out->q = cpp11::as_cpp<std::vector<double>>(q);
}
SEXP interpolation = options["interpolation"];
if (!Rf_isNull(interpolation) && TYPEOF(interpolation) == INTSXP &&
XLENGTH(interpolation) == 1) {
out->interpolation =
cpp11::as_cpp<enum arrow::compute::QuantileOptions::Interpolation>(
interpolation);
}
return out;
}
if (func_name == "is_in" || func_name == "index_in") {
using Options = arrow::compute::SetLookupOptions;
return std::make_shared<Options>(cpp11::as_cpp<arrow::Datum>(options["value_set"]),
cpp11::as_cpp<bool>(options["skip_nulls"]));
}
if (func_name == "dictionary_encode") {
using Options = arrow::compute::DictionaryEncodeOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["null_encoding_behavior"])) {
out->null_encoding_behavior = cpp11::as_cpp<
enum arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior>(
options["null_encoding_behavior"]);
}
return out;
}
if (func_name == "cast") {
return make_cast_options(options);
}
if (func_name == "match_substring" || func_name == "match_substring_regex") {
using Options = arrow::compute::MatchSubstringOptions;
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]));
}
if (func_name == "replace_substring" || func_name == "replace_substring_regex") {
using Options = arrow::compute::ReplaceSubstringOptions;
int64_t max_replacements = -1;
if (!Rf_isNull(options["max_replacements"])) {
max_replacements = cpp11::as_cpp<int64_t>(options["max_replacements"]);
}
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
cpp11::as_cpp<std::string>(options["replacement"]),
max_replacements);
}
if (func_name == "split_pattern") {
using Options = arrow::compute::SplitPatternOptions;
int64_t max_splits = -1;
if (!Rf_isNull(options["max_splits"])) {
max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]);
}
bool reverse = false;
if (!Rf_isNull(options["reverse"])) {
reverse = cpp11::as_cpp<bool>(options["reverse"]);
}
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
max_splits, reverse);
}
if (func_name == "utf8_split_whitespace" || func_name == "ascii_split_whitespace") {
using Options = arrow::compute::SplitOptions;
int64_t max_splits = -1;
if (!Rf_isNull(options["max_splits"])) {
max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]);
}
bool reverse = false;
if (!Rf_isNull(options["reverse"])) {
reverse = cpp11::as_cpp<bool>(options["reverse"]);
}
return std::make_shared<Options>(max_splits, reverse);
}
if (func_name == "variance" || func_name == "stddev") {
using Options = arrow::compute::VarianceOptions;
return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["ddof"]));
}
return nullptr;
}
// Builds arrow::compute::CastOptions from the named R list `options`.
//
// The options object is constructed with `true` and then adjusted:
//   - `to_type` is applied when the entry is non-NULL and unboxes to a
//     non-null DataType pointer;
//   - `allow_float_truncate`, `allow_time_truncate` and `allow_int_overflow`
//     are each switched on only when the entry is present and TRUE. A NULL or
//     FALSE entry leaves the value set by the constructor untouched.
std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options) {
  using Options = arrow::compute::CastOptions;
  auto cast_options = std::make_shared<Options>(true);

  // True only for a non-NULL entry that converts to a true logical value.
  const auto is_enabled = [](SEXP flag) {
    return !Rf_isNull(flag) && cpp11::as_cpp<bool>(flag);
  };

  SEXP requested_type = options["to_type"];
  if (!Rf_isNull(requested_type)) {
    auto target = cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(requested_type);
    if (target) {
      cast_options->to_type = target;
    }
  }

  if (is_enabled(options["allow_float_truncate"])) {
    cast_options->allow_float_truncate = true;
  }

  if (is_enabled(options["allow_time_truncate"])) {
    cast_options->allow_time_truncate = true;
  }

  if (is_enabled(options["allow_int_overflow"])) {
    cast_options->allow_int_overflow = true;
  }

  return cast_options;
}
// Calls the compute function `func_name` on `args` and returns the result as
// an R object.
//
// `options` is converted by make_compute_options(), `args` by
// arrow::r::from_r_list<arrow::Datum>(). The call runs with gc_context() and
// its result is unwrapped with ValueOrStop() before conversion by from_datum().
// [[arrow::export]]
SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list options) {
  auto function_options = make_compute_options(func_name, options);
  auto inputs = arrow::r::from_r_list<arrow::Datum>(args);

  // function_options may hold a null pointer; it is forwarded as-is.
  auto result = ValueOrStop(arrow::compute::CallFunction(
      func_name, inputs, function_options.get(), gc_context()));

  return from_datum(std::move(result));
}
// Runs arrow::compute::internal::GroupBy over `arguments` grouped by `keys`
// and returns the result as an R object.
//
// `options` is a list of pairs: element 0 is the aggregate function name
// (string) and element 1 is the list of options for that function.
// [[arrow::export]]
SEXP compute__GroupBy(cpp11::list arguments, cpp11::list keys, cpp11::list options) {
  // Aggregate stores a raw pointer to its options, so the owning shared_ptrs
  // are parked here to stay alive until the GroupBy call has finished.
  std::vector<std::shared_ptr<arrow::compute::FunctionOptions>> owned_options;
  std::vector<arrow::compute::internal::Aggregate> aggregate_specs;

  for (cpp11::list entry : options) {
    auto function_name = cpp11::as_cpp<std::string>(entry[0]);
    auto function_options = make_compute_options(function_name, entry[1]);

    aggregate_specs.push_back(arrow::compute::internal::Aggregate{
        std::move(function_name), function_options.get()});
    owned_options.push_back(std::move(function_options));
  }

  auto argument_datums = arrow::r::from_r_list<arrow::Datum>(arguments);
  auto key_datums = arrow::r::from_r_list<arrow::Datum>(keys);

  auto result = ValueOrStop(arrow::compute::internal::GroupBy(
      argument_datums, key_datums, aggregate_specs, gc_context()));

  return from_datum(std::move(result));
}
// Returns the function names reported by the registry obtained from
// arrow::compute::GetFunctionRegistry().
// [[arrow::export]]
std::vector<std::string> compute__GetFunctionNames() {
  auto* registry = arrow::compute::GetFunctionRegistry();
  return registry->GetFunctionNames();
}
#endif