blob: a13b252b2832f12dab41a5531b582be90b6c2b11 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "./arrow_types.h"
#include <thread>
#include <arrow/array.h>
#include <arrow/extension_type.h>
#include <arrow/type.h>
#include "./extension.h"
#include "./safe-call-into-r.h"
// Decides whether `other` is equal to this extension type.
//
// Two inexpensive comparisons run first (extension name, then serialized
// metadata). Only when they are inconclusive is the R6 instance's
// ExtensionEquals() method invoked through SafeCallIntoR().
//
// Throws std::runtime_error carrying the status message when the call into R
// does not succeed.
bool RExtensionType::ExtensionEquals(const arrow::ExtensionType& other) const {
  // Different names can never be equal; no R6 instance is needed to know that.
  const bool names_match = other.extension_name() == extension_name();
  if (!names_match) {
    return false;
  }

  // Identical serialized metadata settles the question without calling into R.
  const bool metadata_match = other.Serialize() == Serialize();
  if (metadata_match) {
    return true;
  }

  // Anything else is ambiguous: ask the R6 instance. SafeCallIntoR() is used
  // because the R-level method cannot be run from a non-R thread.
  arrow::Result<bool> r_verdict = SafeCallIntoR<bool>(
      [&]() {
        cpp11::environment self_r6 = r6_instance();
        cpp11::function equals_method(self_r6["ExtensionEquals"]);

        // Rebuild `other` as a shared DataType so it can be handed to R.
        std::shared_ptr<DataType> other_type =
            ValueOrStop(other.Deserialize(other.storage_type(), other.Serialize()));
        cpp11::sexp other_as_r6 = cpp11::to_r6<DataType>(other_type, "ExtensionType");

        cpp11::logicals answer(equals_method(other_as_r6));
        return cpp11::as_cpp<bool>(answer);
      },
      "RExtensionType$ExtensionEquals()");

  if (r_verdict.ok()) {
    return r_verdict.ValueUnsafe();
  }
  throw std::runtime_error(r_verdict.status().message());
}
// Builds an ExtensionArray from `data`.
//
// The incoming ArrayData is copied via Copy() and the copy's type is replaced
// by a fresh clone of this extension type before the array is constructed.
std::shared_ptr<arrow::Array> RExtensionType::MakeArray(
    std::shared_ptr<arrow::ArrayData> data) const {
  std::shared_ptr<arrow::ArrayData> typed_data = data->Copy();
  // Ownership of the clone moves straight from the unique_ptr to a shared_ptr.
  typed_data->type = std::shared_ptr<RExtensionType>(Clone());
  return std::make_shared<arrow::ExtensionArray>(typed_data);
}
// Creates a new extension type that shares this type's name and R6 class but
// carries the supplied storage type and serialized metadata.
//
// storage_type:    storage type to record on the returned type.
// serialized_data: serialized extension metadata to record on the returned type.
//
// Returns the new type as a shared DataType pointer.
arrow::Result<std::shared_ptr<arrow::DataType>> RExtensionType::Deserialize(
    std::shared_ptr<arrow::DataType> storage_type,
    const std::string& serialized_data) const {
  std::unique_ptr<RExtensionType> cloned = Clone();
  cloned->storage_type_ = storage_type;
  cloned->extension_metadata_ = serialized_data;

  // We could create an ephemeral R6 instance here, which will call the R6 instance's
  // deserialize_instance() method, possibly erroring when the metadata is
  // invalid or the deserialized values are invalid. The complexity of setting up
  // an event loop from wherever this *might* be called is high and hard to
  // predict. As a compromise, just create the instance when it is safe to
  // do so.
  //
  // The instance must be created from the clone: it is the clone that holds
  // the freshly deserialized storage type and metadata. Instantiating `this`
  // would only re-check the values this object already had.
  if (MainRThread::GetInstance().IsMainThread()) {
    cloned->r6_instance();
  }

  return std::shared_ptr<RExtensionType>(cloned.release());
}
std::string RExtensionType::ToString() const {
arrow::Result<std::string> result = SafeCallIntoR<std::string>([&]() {
cpp11::environment instance = r6_instance();
cpp11::function instance_ToString(instance["ToString"]);
cpp11::sexp result = instance_ToString();
return cpp11::as_cpp<std::string>(result);
});
// In the event of an error (e.g., we are not on the main thread
// and we are not inside RunWithCapturedR()), just call the default method
if (!result.ok()) {
return ExtensionType::ToString();
} else {
return result.ValueUnsafe();
}
}
// Converts `array` to an R object by calling the R6 instance's as_vector()
// method with the chunked array wrapped as a "ChunkedArray" R6 object.
cpp11::sexp RExtensionType::Convert(
    const std::shared_ptr<arrow::ChunkedArray>& array) const {
  cpp11::environment self_r6 = r6_instance();
  cpp11::function as_vector_method(self_r6["as_vector"]);

  cpp11::sexp chunked_r6 = cpp11::to_r6<arrow::ChunkedArray>(array, "ChunkedArray");
  return as_vector_method(chunked_r6);
}
// Returns a uniquely owned copy of this type, constructed from its storage
// type, extension name, extension metadata and R6 class.
std::unique_ptr<RExtensionType> RExtensionType::Clone() const {
  return std::unique_ptr<RExtensionType>(new RExtensionType(
      storage_type(), extension_name_, extension_metadata_, r6_class_));
}
// Creates an R6 object wrapping a clone of this type that carries the given
// storage type and serialized metadata.
//
// storage_type:    storage type for the wrapped clone. A null pointer leaves
//                  the clone's storage type as copied from this object.
// serialized_data: serialized extension metadata for the wrapped clone.
//
// Returns the environment produced by calling the R6 class's `new` with an
// external pointer to the clone.
cpp11::environment RExtensionType::r6_instance(
    std::shared_ptr<arrow::DataType> storage_type,
    const std::string& serialized_data) const {
  // This is a version of to_r6<>() that is a more direct route to creating the object.
  // This is done to avoid circular calls, since to_r6<>() has to go through
  // ExtensionType$new(), which then calls back to C++ to get r6_class_ to then
  // return the correct subclass.
  std::unique_ptr<RExtensionType> cloned = Clone();

  // Honor the arguments: previously they were ignored and the clone always
  // carried this object's own values. Callers that pass this object's storage
  // type and serialized metadata observe no difference.
  if (storage_type != nullptr) {
    cloned->storage_type_ = storage_type;
  }
  cloned->extension_metadata_ = serialized_data;

  cpp11::external_pointer<std::shared_ptr<RExtensionType>> xp(
      new std::shared_ptr<RExtensionType>(cloned.release()));

  cpp11::function r6_class_new(r6_class()["new"]);
  return r6_class_new(xp);
}
// Builds a temporary RExtensionType from the R-supplied pieces and returns the
// R6 instance created from it.
// [[arrow::export]]
cpp11::environment ExtensionType__initialize(
    const std::shared_ptr<arrow::DataType>& storage_type, std::string extension_name,
    cpp11::raws extension_metadata, cpp11::environment r6_class) {
  // The metadata arrives as a raw vector; copy its bytes into a std::string.
  std::string metadata_bytes(extension_metadata.begin(), extension_metadata.end());
  auto shared_r6_class = std::make_shared<cpp11::environment>(r6_class);

  RExtensionType temporary_type(storage_type, extension_name, metadata_bytes,
                                shared_r6_class);
  return temporary_type.r6_instance();
}
// Returns the extension name reported by `type`.
// [[arrow::export]]
std::string ExtensionType__extension_name(
    const std::shared_ptr<arrow::ExtensionType>& type) {
  std::string name = type->extension_name();
  return name;
}
// Returns the serialized form of `type` as an R raw vector, one element per
// byte of the string returned by Serialize().
// [[arrow::export]]
cpp11::raws ExtensionType__Serialize(const std::shared_ptr<arrow::ExtensionType>& type) {
  std::string payload = type->Serialize();
  cpp11::writable::raws raw_payload(payload.begin(), payload.end());
  return raw_payload;
}
// Returns the storage type reported by `type`.
// [[arrow::export]]
std::shared_ptr<arrow::DataType> ExtensionType__storage_type(
    const std::shared_ptr<arrow::ExtensionType>& type) {
  std::shared_ptr<arrow::DataType> storage = type->storage_type();
  return storage;
}
// Forwards to `type`'s MakeArray() with `data` and returns the resulting array.
// [[arrow::export]]
std::shared_ptr<arrow::Array> ExtensionType__MakeArray(
    const std::shared_ptr<arrow::ExtensionType>& type,
    const std::shared_ptr<arrow::ArrayData>& data) {
  std::shared_ptr<arrow::Array> made = type->MakeArray(data);
  return made;
}
// Returns the R6 class generator associated with `type`.
//
// `type` must be an RExtensionType. The cast is verified at run time so that
// an extension type of any other class produces an R error instead of an
// unchecked downcast.
// [[arrow::export]]
cpp11::environment ExtensionType__r6_class(
    const std::shared_ptr<arrow::ExtensionType>& type) {
  auto r_type = std::dynamic_pointer_cast<RExtensionType>(type);
  if (r_type == nullptr) {
    cpp11::stop("Can't get the R6 class of an extension type that is not an RExtensionType");
  }
  return r_type->r6_class();
}
// Returns the storage array reported by `array`.
// [[arrow::export]]
std::shared_ptr<arrow::Array> ExtensionArray__storage(
    const std::shared_ptr<arrow::ExtensionArray>& array) {
  std::shared_ptr<arrow::Array> storage = array->storage();
  return storage;
}
// Registers `type` with Arrow's extension type registry.
//
// `type` must be an extension type; anything else raises an R error. A failed
// registration is reported through StopIfNotOk().
// [[arrow::export]]
void arrow__RegisterRExtensionType(const std::shared_ptr<arrow::DataType>& type) {
  auto ext_type = std::dynamic_pointer_cast<arrow::ExtensionType>(type);
  // dynamic_pointer_cast yields a null pointer when `type` is not an
  // ExtensionType; never hand that to the registry.
  if (ext_type == nullptr) {
    cpp11::stop("Can't register a type that is not an extension type");
  }
  StopIfNotOk(arrow::RegisterExtensionType(ext_type));
}
// [[arrow::export]]
void arrow__UnregisterRExtensionType(std::string type_name) {
StopIfNotOk(arrow::UnregisterExtensionType(type_name));
}