blob: 578ce74d05d12fb792b837cbd12c78a325c6b07a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "./arrow_types.h"
#include "./safe-call-into-r.h"
#include <arrow/array/util.h>
#include <arrow/compute/api.h>
#include <arrow/record_batch.h>
#include <arrow/table.h>
// Forward declaration: the definition appears further down in this file. It is
// declared here because the cast helpers below call it before that point.
std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options);
// Returns a pointer to a single ExecContext shared by all callers.
//
// The context is a function-local static, so it is constructed on the first call
// (from the pool returned by gc_memory_pool()) and the same address is returned on
// every later call.
arrow::compute::ExecContext* gc_context() {
  static arrow::compute::ExecContext shared_exec_context{gc_memory_pool()};
  arrow::compute::ExecContext* const ctx = &shared_exec_context;
  return ctx;
}
// Casts every column of `batch` to the type of the field at the same position in
// `schema` and returns a new RecordBatch built on `schema`.
//
// `options` is forwarded to make_cast_options(). An R error is raised when the
// number of schema fields differs from the number of batch columns (columns and
// fields are matched by position, so a mismatch would otherwise index past the end
// of the schema or produce a batch with fewer columns than fields). A failed cast
// is surfaced through ValueOrStop().
// [[arrow::export]]
std::shared_ptr<arrow::RecordBatch> RecordBatch__cast(
    const std::shared_ptr<arrow::RecordBatch>& batch,
    const std::shared_ptr<arrow::Schema>& schema, cpp11::list options) {
  auto opts = make_cast_options(options);
  const int nc = batch->num_columns();

  if (schema->num_fields() != nc) {
    cpp11::stop(
        "Cannot cast RecordBatch: target schema must have the same number of fields "
        "as the RecordBatch has columns");
  }

  arrow::ArrayVector columns(nc);
  for (int i = 0; i < nc; i++) {
    columns[i] = ValueOrStop(
        arrow::compute::Cast(*batch->column(i), schema->field(i)->type(), *opts));
  }

  return arrow::RecordBatch::Make(schema, batch->num_rows(), std::move(columns));
}
// Casts every column of `table` to the type of the field at the same position in
// `schema` and returns a new Table built on `schema`.
//
// `options` is forwarded to make_cast_options(). An R error is raised when the
// number of schema fields differs from the number of table columns (columns and
// fields are matched by position, so a mismatch would otherwise index past the end
// of the schema or produce a table with fewer columns than fields). A failed cast
// is surfaced through ValueOrStop().
// [[arrow::export]]
std::shared_ptr<arrow::Table> Table__cast(const std::shared_ptr<arrow::Table>& table,
                                          const std::shared_ptr<arrow::Schema>& schema,
                                          cpp11::list options) {
  auto opts = make_cast_options(options);
  const int nc = table->num_columns();

  if (schema->num_fields() != nc) {
    cpp11::stop(
        "Cannot cast Table: target schema must have the same number of fields "
        "as the Table has columns");
  }

  using ColumnVector = std::vector<std::shared_ptr<arrow::ChunkedArray>>;
  ColumnVector columns(nc);
  for (int i = 0; i < nc; i++) {
    // Wrap the chunked column in a Datum so the generic Cast overload applies to
    // all of its chunks; the result is unwrapped back to a ChunkedArray.
    arrow::Datum value(table->column(i));
    arrow::Datum out =
        ValueOrStop(arrow::compute::Cast(value, schema->field(i)->type(), *opts));
    columns[i] = out.chunked_array();
  }

  return arrow::Table::Make(schema, std::move(columns), table->num_rows());
}
// Converts the R object `x` to a std::shared_ptr<T> when it inherits from both
// "ArrowObject" and `class_name`; otherwise returns a null pointer without
// attempting any conversion.
template <typename T>
std::shared_ptr<T> MaybeUnbox(const char* class_name, SEXP x) {
  const bool is_requested_class =
      Rf_inherits(x, "ArrowObject") && Rf_inherits(x, class_name);
  if (!is_requested_class) {
    return nullptr;
  }
  return cpp11::as_cpp<std::shared_ptr<T>>(x);
}
namespace cpp11 {

// Specialization that turns an R-side Arrow object into an arrow::Datum.
//
// The classes are probed in this order: Array, ChunkedArray, RecordBatch, Table,
// Scalar. The first one that `x` inherits from determines the Datum that is
// returned. Anything else stops with an R error naming the R type of `x`.
template <>
arrow::Datum as_cpp<arrow::Datum>(SEXP x) {
  if (auto as_array = MaybeUnbox<arrow::Array>("Array", x)) {
    return arrow::Datum(as_array);
  }

  if (auto as_chunked = MaybeUnbox<arrow::ChunkedArray>("ChunkedArray", x)) {
    return arrow::Datum(as_chunked);
  }

  if (auto as_batch = MaybeUnbox<arrow::RecordBatch>("RecordBatch", x)) {
    return arrow::Datum(as_batch);
  }

  if (auto as_table = MaybeUnbox<arrow::Table>("Table", x)) {
    return arrow::Datum(as_table);
  }

  if (auto as_scalar = MaybeUnbox<arrow::Scalar>("Scalar", x)) {
    return arrow::Datum(as_scalar);
  }

  // Plain R objects are expected to have been wrapped as Arrow objects before
  // reaching this point. Open question kept from the original: should that
  // wrapping happen here as well, or instead?
  cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x)));
}

}  // namespace cpp11
// Wraps the value held by `datum` as the corresponding R6 object.
//
// Scalar, Array, ChunkedArray, RecordBatch and Table datums are supported; any
// other kind stops with an R error that includes the Datum's string form.
SEXP from_datum(arrow::Datum datum) {
  const auto datum_kind = datum.kind();

  switch (datum_kind) {
    case arrow::Datum::SCALAR:
      return cpp11::to_r6(datum.scalar());
    case arrow::Datum::ARRAY:
      return cpp11::to_r6(datum.make_array());
    case arrow::Datum::CHUNKED_ARRAY:
      return cpp11::to_r6(datum.chunked_array());
    case arrow::Datum::RECORD_BATCH:
      return cpp11::to_r6(datum.record_batch());
    case arrow::Datum::TABLE:
      return cpp11::to_r6(datum.table());
    default:
      // Unsupported kinds fall through to the error below.
      break;
  }

  cpp11::stop("from_datum: Not implemented for Datum %s", datum.ToString().c_str());
}
std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
std::string func_name, cpp11::list options) {
if (func_name == "filter") {
using Options = arrow::compute::FilterOptions;
auto out = std::make_shared<Options>(Options::Defaults());
SEXP keep_na = options["keep_na"];
if (!Rf_isNull(keep_na) && cpp11::as_cpp<bool>(keep_na)) {
out->null_selection_behavior = Options::EMIT_NULL;
}
return out;
}
if (func_name == "take") {
using Options = arrow::compute::TakeOptions;
auto out = std::make_shared<Options>(Options::Defaults());
return out;
}
if (func_name == "array_sort_indices") {
using Order = arrow::compute::SortOrder;
using Options = arrow::compute::ArraySortOptions;
// false means descending, true means ascending
auto order = cpp11::as_cpp<bool>(options["order"]);
auto out =
std::make_shared<Options>(Options(order ? Order::Descending : Order::Ascending));
return out;
}
if (func_name == "sort_indices") {
using Key = arrow::compute::SortKey;
using Order = arrow::compute::SortOrder;
using Options = arrow::compute::SortOptions;
auto names = cpp11::as_cpp<std::vector<std::string>>(options["names"]);
// false means descending, true means ascending
// cpp11 does not support bool here so use int
auto orders = cpp11::as_cpp<std::vector<int>>(options["orders"]);
std::vector<Key> keys;
for (size_t i = 0; i < names.size(); i++) {
keys.push_back(
Key(names[i], (orders[i] > 0) ? Order::Descending : Order::Ascending));
}
auto out = std::make_shared<Options>(Options(keys));
return out;
}
if (func_name == "all" || func_name == "hash_all" || func_name == "any" ||
func_name == "hash_any" || func_name == "approximate_median" ||
func_name == "hash_approximate_median" || func_name == "mean" ||
func_name == "hash_mean" || func_name == "min_max" || func_name == "hash_min_max" ||
func_name == "min" || func_name == "hash_min" || func_name == "max" ||
func_name == "hash_max" || func_name == "sum" || func_name == "hash_sum") {
using Options = arrow::compute::ScalarAggregateOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["min_count"])) {
out->min_count = cpp11::as_cpp<int>(options["min_count"]);
}
if (!Rf_isNull(options["skip_nulls"])) {
out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
}
return out;
}
if (func_name == "tdigest" || func_name == "hash_tdigest") {
using Options = arrow::compute::TDigestOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["q"])) {
out->q = cpp11::as_cpp<std::vector<double>>(options["q"]);
}
if (!Rf_isNull(options["skip_nulls"])) {
out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
}
return out;
}
if (func_name == "count") {
using Options = arrow::compute::CountOptions;
auto out = std::make_shared<Options>(Options::Defaults());
out->mode =
cpp11::as_cpp<bool>(options["na.rm"]) ? Options::ONLY_VALID : Options::ONLY_NULL;
return out;
}
if (func_name == "count_distinct" || func_name == "hash_count_distinct") {
using Options = arrow::compute::CountOptions;
auto out = std::make_shared<Options>(Options::Defaults());
out->mode =
cpp11::as_cpp<bool>(options["na.rm"]) ? Options::ONLY_VALID : Options::ALL;
return out;
}
if (func_name == "min_element_wise" || func_name == "max_element_wise") {
using Options = arrow::compute::ElementWiseAggregateOptions;
bool skip_nulls = true;
if (!Rf_isNull(options["skip_nulls"])) {
skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
}
return std::make_shared<Options>(skip_nulls);
}
if (func_name == "quantile") {
using Options = arrow::compute::QuantileOptions;
auto out = std::make_shared<Options>(Options::Defaults());
SEXP q = options["q"];
if (!Rf_isNull(q) && TYPEOF(q) == REALSXP) {
out->q = cpp11::as_cpp<std::vector<double>>(q);
}
SEXP interpolation = options["interpolation"];
if (!Rf_isNull(interpolation) && TYPEOF(interpolation) == INTSXP &&
XLENGTH(interpolation) == 1) {
out->interpolation =
cpp11::as_cpp<enum arrow::compute::QuantileOptions::Interpolation>(
interpolation);
}
if (!Rf_isNull(options["min_count"])) {
out->min_count = cpp11::as_cpp<int64_t>(options["min_count"]);
}
if (!Rf_isNull(options["skip_nulls"])) {
out->skip_nulls = cpp11::as_cpp<int64_t>(options["skip_nulls"]);
}
return out;
}
if (func_name == "is_in" || func_name == "index_in") {
using Options = arrow::compute::SetLookupOptions;
return std::make_shared<Options>(cpp11::as_cpp<arrow::Datum>(options["value_set"]),
cpp11::as_cpp<bool>(options["skip_nulls"]));
}
if (func_name == "index") {
using Options = arrow::compute::IndexOptions;
return std::make_shared<Options>(
cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(options["value"]));
}
if (func_name == "is_null") {
using Options = arrow::compute::NullOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["nan_is_null"])) {
out->nan_is_null = cpp11::as_cpp<bool>(options["nan_is_null"]);
}
return out;
}
if (func_name == "dictionary_encode") {
using Options = arrow::compute::DictionaryEncodeOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["null_encoding_behavior"])) {
out->null_encoding_behavior = cpp11::as_cpp<
enum arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior>(
options["null_encoding_behavior"]);
}
return out;
}
if (func_name == "cast") {
return make_cast_options(options);
}
if (func_name == "binary_join_element_wise") {
using Options = arrow::compute::JoinOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["null_handling"])) {
out->null_handling =
cpp11::as_cpp<enum arrow::compute::JoinOptions::NullHandlingBehavior>(
options["null_handling"]);
}
if (!Rf_isNull(options["null_replacement"])) {
out->null_replacement = cpp11::as_cpp<std::string>(options["null_replacement"]);
}
return out;
}
if (func_name == "make_struct") {
using Options = arrow::compute::MakeStructOptions;
// TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options
return std::make_shared<Options>(
cpp11::as_cpp<std::vector<std::string>>(options["field_names"]));
}
if (func_name == "match_substring" || func_name == "match_substring_regex" ||
func_name == "find_substring" || func_name == "find_substring_regex" ||
func_name == "match_like" || func_name == "starts_with" ||
func_name == "ends_with" || func_name == "count_substring" ||
func_name == "count_substring_regex") {
using Options = arrow::compute::MatchSubstringOptions;
bool ignore_case = false;
if (!Rf_isNull(options["ignore_case"])) {
ignore_case = cpp11::as_cpp<bool>(options["ignore_case"]);
}
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
ignore_case);
}
if (func_name == "replace_substring" || func_name == "replace_substring_regex") {
using Options = arrow::compute::ReplaceSubstringOptions;
int64_t max_replacements = -1;
if (!Rf_isNull(options["max_replacements"])) {
max_replacements = cpp11::as_cpp<int64_t>(options["max_replacements"]);
}
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
cpp11::as_cpp<std::string>(options["replacement"]),
max_replacements);
}
if (func_name == "extract_regex") {
using Options = arrow::compute::ExtractRegexOptions;
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]));
}
if (func_name == "day_of_week") {
using Options = arrow::compute::DayOfWeekOptions;
bool count_from_zero = false;
if (!Rf_isNull(options["count_from_zero"])) {
count_from_zero = cpp11::as_cpp<bool>(options["count_from_zero"]);
}
return std::make_shared<Options>(count_from_zero,
cpp11::as_cpp<uint32_t>(options["week_start"]));
}
if (func_name == "iso_week") {
return std::make_shared<arrow::compute::WeekOptions>(
arrow::compute::WeekOptions::ISODefaults());
}
if (func_name == "us_week") {
return std::make_shared<arrow::compute::WeekOptions>(
arrow::compute::WeekOptions::USDefaults());
}
if (func_name == "week") {
using Options = arrow::compute::WeekOptions;
bool week_starts_monday = true;
bool count_from_zero = false;
bool first_week_is_fully_in_year = false;
if (!Rf_isNull(options["week_starts_monday"])) {
week_starts_monday = cpp11::as_cpp<bool>(options["week_starts_monday"]);
}
if (!Rf_isNull(options["count_from_zero"])) {
count_from_zero = cpp11::as_cpp<bool>(options["count_from_zero"]);
}
if (!Rf_isNull(options["first_week_is_fully_in_year"])) {
count_from_zero = cpp11::as_cpp<bool>(options["first_week_is_fully_in_year"]);
}
return std::make_shared<Options>(week_starts_monday, count_from_zero,
first_week_is_fully_in_year);
}
if (func_name == "strptime") {
using Options = arrow::compute::StrptimeOptions;
bool error_is_null = false;
if (!Rf_isNull(options["error_is_null"])) {
error_is_null = cpp11::as_cpp<bool>(options["error_is_null"]);
}
return std::make_shared<Options>(
cpp11::as_cpp<std::string>(options["format"]),
cpp11::as_cpp<arrow::TimeUnit::type>(options["unit"]), error_is_null);
}
if (func_name == "strftime") {
using Options = arrow::compute::StrftimeOptions;
return std::make_shared<Options>(
Options(cpp11::as_cpp<std::string>(options["format"]),
cpp11::as_cpp<std::string>(options["locale"])));
}
if (func_name == "assume_timezone") {
using Options = arrow::compute::AssumeTimezoneOptions;
enum Options::Ambiguous ambiguous = Options::AMBIGUOUS_RAISE;
enum Options::Nonexistent nonexistent = Options::NONEXISTENT_RAISE;
if (!Rf_isNull(options["ambiguous"])) {
ambiguous = cpp11::as_cpp<enum Options::Ambiguous>(options["ambiguous"]);
}
if (!Rf_isNull(options["nonexistent"])) {
nonexistent = cpp11::as_cpp<enum Options::Nonexistent>(options["nonexistent"]);
}
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["timezone"]),
ambiguous, nonexistent);
}
if (func_name == "split_pattern" || func_name == "split_pattern_regex") {
using Options = arrow::compute::SplitPatternOptions;
int64_t max_splits = -1;
if (!Rf_isNull(options["max_splits"])) {
max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]);
}
bool reverse = false;
if (!Rf_isNull(options["reverse"])) {
reverse = cpp11::as_cpp<bool>(options["reverse"]);
}
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["pattern"]),
max_splits, reverse);
}
if (func_name == "utf8_lpad" || func_name == "utf8_rpad" ||
func_name == "utf8_center" || func_name == "ascii_lpad" ||
func_name == "ascii_rpad" || func_name == "ascii_center") {
using Options = arrow::compute::PadOptions;
return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["width"]),
cpp11::as_cpp<std::string>(options["padding"]));
}
if (func_name == "utf8_split_whitespace" || func_name == "ascii_split_whitespace") {
using Options = arrow::compute::SplitOptions;
int64_t max_splits = -1;
if (!Rf_isNull(options["max_splits"])) {
max_splits = cpp11::as_cpp<int64_t>(options["max_splits"]);
}
bool reverse = false;
if (!Rf_isNull(options["reverse"])) {
reverse = cpp11::as_cpp<bool>(options["reverse"]);
}
return std::make_shared<Options>(max_splits, reverse);
}
if (func_name == "utf8_trim" || func_name == "utf8_ltrim" ||
func_name == "utf8_rtrim" || func_name == "ascii_trim" ||
func_name == "ascii_ltrim" || func_name == "ascii_rtrim") {
using Options = arrow::compute::TrimOptions;
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["characters"]));
}
if (func_name == "utf8_slice_codeunits" || func_name == "binary_slice") {
using Options = arrow::compute::SliceOptions;
int64_t step = 1;
if (!Rf_isNull(options["step"])) {
step = cpp11::as_cpp<int64_t>(options["step"]);
}
int64_t stop = std::numeric_limits<int32_t>::max();
if (!Rf_isNull(options["stop"])) {
stop = cpp11::as_cpp<int64_t>(options["stop"]);
}
return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["start"]), stop,
step);
}
if (func_name == "utf8_replace_slice" || func_name == "binary_replace_slice") {
using Options = arrow::compute::ReplaceSliceOptions;
return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["start"]),
cpp11::as_cpp<int64_t>(options["stop"]),
cpp11::as_cpp<std::string>(options["replacement"]));
}
if (func_name == "variance" || func_name == "stddev" || func_name == "hash_variance" ||
func_name == "hash_stddev") {
using Options = arrow::compute::VarianceOptions;
auto out = std::make_shared<Options>();
out->ddof = cpp11::as_cpp<int64_t>(options["ddof"]);
if (!Rf_isNull(options["min_count"])) {
out->min_count = cpp11::as_cpp<int64_t>(options["min_count"]);
}
if (!Rf_isNull(options["skip_nulls"])) {
out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
}
return out;
}
if (func_name == "mode") {
using Options = arrow::compute::ModeOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["n"])) {
out->n = cpp11::as_cpp<int64_t>(options["n"]);
}
if (!Rf_isNull(options["min_count"])) {
out->min_count = cpp11::as_cpp<uint32_t>(options["min_count"]);
}
if (!Rf_isNull(options["skip_nulls"])) {
out->skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
}
return out;
}
if (func_name == "partition_nth_indices") {
using Options = arrow::compute::PartitionNthOptions;
return std::make_shared<Options>(cpp11::as_cpp<int64_t>(options["pivot"]));
}
if (func_name == "round") {
using Options = arrow::compute::RoundOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["ndigits"])) {
out->ndigits = cpp11::as_cpp<int64_t>(options["ndigits"]);
}
SEXP round_mode = options["round_mode"];
if (!Rf_isNull(round_mode)) {
out->round_mode = cpp11::as_cpp<enum arrow::compute::RoundMode>(round_mode);
}
return out;
}
if (func_name == "round_temporal" || func_name == "floor_temporal" ||
func_name == "ceil_temporal") {
using Options = arrow::compute::RoundTemporalOptions;
int64_t multiple = 1;
enum arrow::compute::CalendarUnit unit = arrow::compute::CalendarUnit::DAY;
bool week_starts_monday = true;
bool ceil_is_strictly_greater = true;
bool calendar_based_origin = true;
if (!Rf_isNull(options["multiple"])) {
multiple = cpp11::as_cpp<int64_t>(options["multiple"]);
}
if (!Rf_isNull(options["unit"])) {
unit = cpp11::as_cpp<enum arrow::compute::CalendarUnit>(options["unit"]);
}
if (!Rf_isNull(options["week_starts_monday"])) {
week_starts_monday = cpp11::as_cpp<bool>(options["week_starts_monday"]);
}
if (!Rf_isNull(options["ceil_is_strictly_greater"])) {
ceil_is_strictly_greater = cpp11::as_cpp<bool>(options["ceil_is_strictly_greater"]);
}
if (!Rf_isNull(options["calendar_based_origin"])) {
calendar_based_origin = cpp11::as_cpp<bool>(options["calendar_based_origin"]);
}
return std::make_shared<Options>(multiple, unit, week_starts_monday,
ceil_is_strictly_greater, calendar_based_origin);
}
if (func_name == "round_to_multiple") {
using Options = arrow::compute::RoundToMultipleOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["multiple"])) {
out->multiple = std::make_shared<arrow::DoubleScalar>(
cpp11::as_cpp<double>(options["multiple"]));
}
SEXP round_mode = options["round_mode"];
if (!Rf_isNull(round_mode)) {
out->round_mode = cpp11::as_cpp<enum arrow::compute::RoundMode>(round_mode);
}
return out;
}
if (func_name == "struct_field") {
using Options = arrow::compute::StructFieldOptions;
if (!Rf_isNull(options["indices"])) {
return std::make_shared<Options>(
cpp11::as_cpp<std::vector<int>>(options["indices"]));
} else {
// field_ref
return std::make_shared<Options>(
*cpp11::as_cpp<std::shared_ptr<arrow::compute::Expression>>(
options["field_ref"])
->field_ref());
}
}
return nullptr;
}
// Builds CastOptions from the R list `options`.
//
// The options object is constructed with `true` as its constructor argument and is
// then adjusted as follows:
//   - "to_type": copied into `to_type` when present and non-null after conversion.
//   - "allow_float_truncate", "allow_time_truncate", "allow_int_overflow": the
//     matching flag is set to true when the entry is present and true. An entry
//     that is absent or false leaves the flag as constructed.
std::shared_ptr<arrow::compute::CastOptions> make_cast_options(cpp11::list options) {
  auto cast_opts = std::make_shared<arrow::compute::CastOptions>(true);

  SEXP to_type = options["to_type"];
  if (!Rf_isNull(to_type)) {
    auto target_type = cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(to_type);
    if (target_type) {
      cast_opts->to_type = target_type;
    }
  }

  // True only when the named entry exists and converts to `true`.
  auto flag_requested = [&options](const char* entry_name) -> bool {
    SEXP entry = options[entry_name];
    if (Rf_isNull(entry)) {
      return false;
    }
    return cpp11::as_cpp<bool>(entry);
  };

  if (flag_requested("allow_float_truncate")) {
    cast_opts->allow_float_truncate = true;
  }
  if (flag_requested("allow_time_truncate")) {
    cast_opts->allow_time_truncate = true;
  }
  if (flag_requested("allow_int_overflow")) {
    cast_opts->allow_int_overflow = true;
  }

  return cast_opts;
}
// Calls the compute function registered under `func_name`.
//
// `options` is turned into a FunctionOptions object by make_compute_options()
// (which may yield a null pointer), `args` is converted to a list of Datums, and
// the call runs with the context returned by gc_context(). The resulting Datum is
// handed back to R through from_datum().
// [[arrow::export]]
SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list options) {
  auto function_options = make_compute_options(func_name, options);
  auto inputs = arrow::r::from_r_list<arrow::Datum>(args);

  arrow::compute::ExecContext* exec_context = gc_context();
  auto call_result = arrow::compute::CallFunction(func_name, inputs,
                                                  function_options.get(), exec_context);

  auto output = ValueOrStop(std::move(call_result));
  return from_datum(std::move(output));
}
// Returns the names reported by the compute function registry obtained from
// arrow::compute::GetFunctionRegistry().
// [[arrow::export]]
std::vector<std::string> compute__GetFunctionNames() {
  auto* function_registry = arrow::compute::GetFunctionRegistry();
  return function_registry->GetFunctionNames();
}
class RScalarUDFKernelState : public arrow::compute::KernelState {
public:
RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
: exec_func_(exec_func), resolver_(resolver) {}
cpp11::sexp exec_func_;
cpp11::sexp resolver_;
};
// Output-type resolver for kernels of an R scalar user-defined function.
//
// Looks up the RScalarUDFKernelState stored in the kernel's `data`, wraps each
// input type as an R6 DataType, calls the state's R resolver function with that
// list, and converts the returned DataType into an arrow::TypeHolder. All of this
// runs inside SafeCallIntoR(); problems are reported with cpp11::stop() from within
// that callback:
//   - the kernel's state is not an RScalarUDFKernelState, or
//   - the resolver returns something that does not inherit from "DataType".
arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
    arrow::compute::KernelContext* context,
    const std::vector<arrow::TypeHolder>& input_types) {
  return SafeCallIntoR<arrow::TypeHolder>(
      [&]() -> arrow::TypeHolder {
        // Base-to-derived pointer conversion: static_cast rather than
        // reinterpret_cast.
        auto kernel =
            static_cast<const arrow::compute::ScalarKernel*>(context->kernel());
        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
        // dynamic_pointer_cast yields null on a type mismatch; fail loudly instead
        // of dereferencing it.
        if (state == nullptr) {
          cpp11::stop(
              "Kernel state for user-defined function is not an RScalarUDFKernelState");
        }

        cpp11::writable::list input_types_sexp(input_types.size());
        for (size_t i = 0; i < input_types.size(); i++) {
          input_types_sexp[i] =
              cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
        }

        cpp11::sexp output_type_sexp =
            cpp11::function(state->resolver_)(input_types_sexp);
        if (!Rf_inherits(output_type_sexp, "DataType")) {
          cpp11::stop(
              "Function specified as arrow_scalar_function() out_type argument must "
              "return a DataType");
        }

        return arrow::TypeHolder(
            cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
      },
      "resolve scalar user-defined function output data type");
}
// Exec callback for kernels of an R scalar user-defined function.
//
// Returns NotImplemented when `result` is an ArraySpan. Otherwise, inside
// SafeCallIntoRVoid(), it:
//   1. fetches the RScalarUDFKernelState stored in the kernel's `data`;
//   2. wraps every value of `span` as an R6 Array or Scalar;
//   3. builds a context list with "batch_length" and "output_type";
//   4. calls the state's R exec function with (context, args);
//   5. stores the returned Array's data in `result`, or, for a returned Scalar,
//      the data of an array of `span.length` copies of it.
// cpp11::stop() is raised when the state has the wrong type, when the returned
// value has a type different from `result`'s type, or when it is neither an Array
// nor a Scalar.
arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
                             const arrow::compute::ExecSpan& span,
                             arrow::compute::ExecResult* result) {
  if (result->is_array_span()) {
    return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
  }

  return SafeCallIntoRVoid(
      [&]() {
        // Base-to-derived pointer conversion: static_cast rather than
        // reinterpret_cast.
        auto kernel =
            static_cast<const arrow::compute::ScalarKernel*>(context->kernel());
        auto state = std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
        // dynamic_pointer_cast yields null on a type mismatch; fail loudly instead
        // of dereferencing it.
        if (state == nullptr) {
          cpp11::stop(
              "Kernel state for user-defined function is not an RScalarUDFKernelState");
        }

        cpp11::writable::list args_sexp(span.num_values());
        for (int i = 0; i < span.num_values(); i++) {
          const arrow::compute::ExecValue& exec_val = span[i];
          if (exec_val.is_array()) {
            args_sexp[i] = cpp11::to_r6<arrow::Array>(exec_val.array.ToArray());
          } else if (exec_val.is_scalar()) {
            args_sexp[i] = cpp11::to_r6<arrow::Scalar>(exec_val.scalar->GetSharedPtr());
          }
        }

        cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);

        std::shared_ptr<arrow::DataType> output_type = result->type()->GetSharedPtr();
        cpp11::sexp output_type_sexp = cpp11::to_r6<arrow::DataType>(output_type);
        cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
        udf_context.names() = {"batch_length", "output_type"};

        cpp11::sexp func_result_sexp =
            cpp11::function(state->exec_func_)(udf_context, args_sexp);

        if (Rf_inherits(func_result_sexp, "Array")) {
          auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);

          // Error for an Array result of the wrong type
          if (!result->type()->Equals(array->type())) {
            cpp11::stop(
                "Expected return Array or Scalar with type '%s' from user-defined "
                "function but got Array with type '%s'",
                result->type()->ToString().c_str(), array->type()->ToString().c_str());
          }

          // data() hands back a const reference, so this is a shared_ptr copy.
          result->value = array->data();
        } else if (Rf_inherits(func_result_sexp, "Scalar")) {
          auto scalar = cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(func_result_sexp);

          // handle a Scalar result of the wrong type
          if (!result->type()->Equals(scalar->type)) {
            cpp11::stop(
                "Expected return Array or Scalar with type '%s' from user-defined "
                "function but got Scalar with type '%s'",
                result->type()->ToString().c_str(), scalar->type->ToString().c_str());
          }

          // Expand the scalar to an array covering the whole batch.
          auto array = ValueOrStop(
              arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool()));
          result->value = array->data();
        } else {
          cpp11::stop("arrow_scalar_function must return an Array or Scalar");
        }
      },
      "execute scalar user-defined function");
}
// [[arrow::export]]
void RegisterScalarUDF(std::string name, cpp11::list func_sexp) {
cpp11::list in_type_r(func_sexp["in_type"]);
cpp11::list out_type_r(func_sexp["out_type"]);
R_xlen_t n_kernels = in_type_r.size();
if (n_kernels == 0) {
cpp11::stop("Can't register user-defined function with zero kernels");
}
// Compute the Arity from the list of input kernels. We don't currently handle
// variable numbers of arguments in a user-defined function.
int64_t n_args =
cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[0])->num_fields();
for (R_xlen_t i = 1; i < n_kernels; i++) {
auto in_types = cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[i]);
if (in_types->num_fields() != n_args) {
cpp11::stop(
"Kernels for user-defined function must accept the same number of arguments");
}
}
arrow::compute::Arity arity(n_args, false);
// The function documentation isn't currently accessible from R but is required
// for the C++ function constructor.
std::vector<std::string> dummy_argument_names(n_args);
for (int64_t i = 0; i < n_args; i++) {
dummy_argument_names[i] = "arg";
}
const arrow::compute::FunctionDoc dummy_function_doc{
"A user-defined R function", "returns something", std::move(dummy_argument_names)};
auto func =
std::make_shared<arrow::compute::ScalarFunction>(name, arity, dummy_function_doc);
for (R_xlen_t i = 0; i < n_kernels; i++) {
auto in_types = cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[i]);
cpp11::sexp out_type_func = out_type_r[i];
std::vector<arrow::compute::InputType> compute_in_types(in_types->num_fields());
for (int64_t j = 0; j < in_types->num_fields(); j++) {
compute_in_types[j] = arrow::compute::InputType(in_types->field(j)->type());
}
arrow::compute::OutputType out_type((&ResolveScalarUDFOutputType));
auto signature = std::make_shared<arrow::compute::KernelSignature>(
std::move(compute_in_types), std::move(out_type), true);
arrow::compute::ScalarKernel kernel(signature, &CallRScalarUDF);
kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
kernel.data =
std::make_shared<RScalarUDFKernelState>(func_sexp["wrapper_fun"], out_type_func);
StopIfNotOk(func->AddKernel(std::move(kernel)));
}
auto registry = arrow::compute::GetFunctionRegistry();
StopIfNotOk(registry->AddFunction(std::move(func), true));
}