blob: 1a76ada1995c92e7e5e2b13a6e674819899e9e10 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <cmath>
#include <functional> /* for std::multiplies */
#include <numeric> /* for std::accumulate */
#include "feather_writer.h"
#include <arrow/array.h>
#include <arrow/buffer.h>
#include <arrow/io/file.h>
#include <arrow/ipc/feather.h>
#include <arrow/status.h>
#include <arrow/table.h>
#include <arrow/type.h>
#include <arrow/util/bit_util.h>
#include <arrow/util/bitmap_generate.h>
#include <arrow/util/key_value_metadata.h>
#include <mex.h>
#include "matlab_traits.h"
#include "util/handle_status.h"
namespace arrow {
namespace matlab {
namespace internal {
// Returns the arrow::DataType that corresponds to the input type string
std::shared_ptr<arrow::DataType> ConvertMatlabTypeStringToArrowDataType(
const std::string& t) {
if (t == "double") {
return arrow::float64();
} else if (t == "single") {
return arrow::float32();
} else if (t == "uint64") {
return arrow::uint64();
} else if (t == "uint32") {
return arrow::uint32();
} else if (t == "uint16") {
return arrow::uint16();
} else if (t == "uint8") {
return arrow::uint8();
} else if (t == "int64") {
return arrow::int64();
} else if (t == "int32") {
return arrow::int32();
} else if (t == "int16") {
return arrow::int16();
} else if (t == "int8") {
return arrow::int8();
}
mexErrMsgIdAndTxt("MATLAB:arrow:UnsupportedMatlabTypeString",
"Unsupported MATLAB type string: '%s'", t.c_str());
// mexErrMsgIdAndTxt throws unconditionally so we should never reach this line
return nullptr;
}
// Utility that helps verify the input mxArray struct field name and type.
// Returns void since any errors will throw and terminate MEX execution.
void ValidateMxStructField(const mxArray* struct_array, const char* fieldname,
mxClassID expected_class_id, bool can_be_empty) {
// Check that the input mxArray is a struct array.
if (!mxIsStruct(struct_array)) {
mexErrMsgIdAndTxt("MATLAB:arrow:IncorrectDimensionsOrType",
"Input needs to be a struct array");
}
// Return early if an empty table is provided as input.
if (mxIsEmpty(struct_array)) {
return;
}
mxArray* field = mxGetField(struct_array, 0, fieldname);
if (!field) {
mexErrMsgIdAndTxt("MATLAB:arrow:MissingStructField",
"Missing field '%s' in input struct array", fieldname);
}
mxClassID actual_class_id = mxGetClassID(field);
// Avoid type check if an mxUNKNOWN_CLASS is provided since the UNKNOWN type is used to
// signify genericity in the input type.
if (expected_class_id != mxUNKNOWN_CLASS) {
if (expected_class_id != actual_class_id) {
mexErrMsgIdAndTxt("MATLAB:arrow:MissingStructField",
"Incorrect type '%s' for struct array field '%s'",
mxGetClassName(field), fieldname);
}
}
// Some struct fields (like Data) can be empty, while others
// (like NumRows) should never be empty. This conditional helps account for both cases.
if (!can_be_empty) {
// Ensure that individual mxStructArray fields are non-empty.
// We can call mxGetData after this without needing another null check.
if (mxIsEmpty(field)) {
mexErrMsgIdAndTxt("MATLAB:arrow:EmptyStructField",
"Struct array field '%s' cannot be empty", fieldname);
}
}
}
// Utility function to convert mxChar mxArray* to std::string while preserving
// Unicode code points.
std::string MxArrayToString(const mxArray* array) {
// Return empty std::string if a mxChar array is not passed in.
if (!mxIsChar(array)) {
return std::string();
}
// Convert mxArray first to a C-style char array, then copy into a std::string.
char* utf8_array = mxArrayToUTF8String(array);
std::string output(utf8_array);
// Free the allocated char* from the MEX runtime.
mxFree(utf8_array);
return output;
}
// Compare number of columns and exit out to the MATLAB layer if incorrect.
void ValidateNumColumns(int64_t actual, int64_t expected) {
if (actual != expected) {
mexErrMsgIdAndTxt("MATLAB:arrow:IncorrectNumberOfColumns",
"Received only '%d' columns but expected '%d' columns", actual,
expected);
}
}
// Compare number of rows and exit out to the MATLAB layer if incorrect.
void ValidateNumRows(int64_t actual, int64_t expected) {
if (actual != expected) {
mexErrMsgIdAndTxt("MATLAB:arrow:IncorrectNumberOfRows",
"Received only '%d' rows but expected '%d' rows", actual, expected);
}
}
// Calculate the number of bytes required in the bit-packed validity buffer.
int64_t BitPackedLength(int64_t num_elements) {
// Since mxLogicalArray encodes [0, 1] in a full byte, we can compress that byte
// down to a bit...therefore dividing the mxLogicalArray length by 8 here.
return static_cast<int64_t>(std::ceil(num_elements / 8.0));
}
// Calculate the total number of elements in an mxArray
// We have to do this separately since mxGetNumberOfElements only works in numeric arrays
size_t GetNumberOfElements(const mxArray* array) {
// Get the dimensions and the total number of dimensions from the mxArray*.
const size_t num_dimensions = mxGetNumberOfDimensions(array);
const size_t* dimensions = mxGetDimensions(array);
// Iterate over the dimensions array and accumulate the total number of elements.
return std::accumulate(dimensions, dimensions + num_dimensions, size_t{1},
std::multiplies<size_t>());
}
// Write an mxLogicalArray* into a bit-packed arrow::MutableBuffer
void BitPackBuffer(const mxArray* logical_array,
std::shared_ptr<MutableBuffer> packed_buffer) {
// Error out if the incorrect type is passed in.
if (!mxIsLogical(logical_array)) {
mexErrMsgIdAndTxt(
"MATLAB:arrow:IncorrectType",
"Expected mxLogical array as input but received mxArray of class '%s'",
mxGetClassName(logical_array));
}
// Validate that the input arrow::Buffer has sufficient size to store a full bit-packed
// representation of the input mxLogicalArray
int64_t unpacked_buffer_length = GetNumberOfElements(logical_array);
if (BitPackedLength(unpacked_buffer_length) > packed_buffer->capacity()) {
mexErrMsgIdAndTxt("MATLAB:arrow:BufferSizeExceeded",
"Buffer of size %d bytes cannot store %d bytes of data",
packed_buffer->capacity(), BitPackedLength(unpacked_buffer_length));
}
// Get pointers to the internal uint8_t arrays behind arrow::Buffer and mxArray
uint8_t* packed_buffer_ptr = packed_buffer->mutable_data();
mxLogical* unpacked_buffer_ptr = mxGetLogicals(logical_array);
// Iterate over the mxLogical array and write bit-packed bools to the arrow::Buffer.
// Call into a loop-unrolled Arrow utility for better performance when bit-packing.
auto generator = [&]() -> bool { return *(unpacked_buffer_ptr++); };
const int64_t start_offset = 0;
arrow::internal::GenerateBitsUnrolled(packed_buffer_ptr, start_offset,
unpacked_buffer_length, generator);
}
// Write numeric datatypes to the Feather file.
template <typename ArrowDataType>
std::unique_ptr<Array> WriteNumericData(const mxArray* data,
const std::shared_ptr<Buffer> validity_bitmap) {
// Alias the type name for the underlying MATLAB type.
using MatlabType = typename MatlabTraits<ArrowDataType>::MatlabType;
// Get a pointer to the underlying mxArray data.
// We need to (temporarily) cast away const here since the mxGet* functions do not
// accept a const input parameter for compatibility reasons.
const MatlabType* dt = MatlabTraits<ArrowDataType>::GetData(const_cast<mxArray*>(data));
// Construct an arrow::Buffer that points to the underlying mxArray without copying.
// - The lifetime of the mxArray buffer exceeds that of the arrow::Buffer here since
// MATLAB should only free this region on garbage-collection after the MEX function
// is executed. Therefore it is safe for arrow::Buffer to point to this location.
// - However arrow::Buffer must not free this region by itself, since that could cause
// segfaults if the input array is used later in MATLAB.
// - The Doxygen doc for arrow::Buffer's constructor implies that it is not an RAII
// type, so this should be safe from possible double-free here.
std::shared_ptr<Buffer> buffer =
std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(dt),
mxGetElementSize(data) * mxGetNumberOfElements(data));
// Construct arrow::NumericArray specialization using arrow::Buffer.
// Pass in nulls information...we could compute and provide the number of nulls here
// too, but passing -1 for now so that Arrow recomputes it if necessary.
return std::unique_ptr<Array>(new NumericArray<ArrowDataType>(
mxGetNumberOfElements(data), buffer, validity_bitmap, -1));
}
// Dispatch MATLAB column data to the correct arrow::Array converter.
std::unique_ptr<Array> WriteVariableData(const mxArray* data, const std::string& type,
const std::shared_ptr<Buffer> validity_bitmap) {
// Get the underlying type of the mxArray data.
const mxClassID mxclass = mxGetClassID(data);
switch (mxclass) {
case mxSINGLE_CLASS:
return WriteNumericData<FloatType>(data, validity_bitmap);
case mxDOUBLE_CLASS:
return WriteNumericData<DoubleType>(data, validity_bitmap);
case mxUINT8_CLASS:
return WriteNumericData<UInt8Type>(data, validity_bitmap);
case mxUINT16_CLASS:
return WriteNumericData<UInt16Type>(data, validity_bitmap);
case mxUINT32_CLASS:
return WriteNumericData<UInt32Type>(data, validity_bitmap);
case mxUINT64_CLASS:
return WriteNumericData<UInt64Type>(data, validity_bitmap);
case mxINT8_CLASS:
return WriteNumericData<Int8Type>(data, validity_bitmap);
case mxINT16_CLASS:
return WriteNumericData<Int16Type>(data, validity_bitmap);
case mxINT32_CLASS:
return WriteNumericData<Int32Type>(data, validity_bitmap);
case mxINT64_CLASS:
return WriteNumericData<Int64Type>(data, validity_bitmap);
default: {
mexErrMsgIdAndTxt("MATLAB:arrow:UnsupportedArrowType",
"Unsupported arrow::Type '%s' for variable '%s'",
mxGetClassName(data), type.c_str());
}
}
// We shouldn't ever reach this branch, but if we do, return nullptr.
return nullptr;
}
} // namespace internal
Status FeatherWriter::Open(const std::string& filename,
std::shared_ptr<FeatherWriter>* feather_writer) {
// Allocate shared_ptr out parameter.
*feather_writer = std::shared_ptr<FeatherWriter>(new FeatherWriter());
// Open a FileOutputStream corresponding to the provided filename.
ARROW_ASSIGN_OR_RAISE((*feather_writer)->file_output_stream_,
io::FileOutputStream::Open(filename, &((*feather_writer)->file_output_stream_)));
return Status::OK();
}
// Write mxArrays from MATLAB into a Feather file.
Status FeatherWriter::WriteVariables(const mxArray* variables, const mxArray* metadata) {
// Verify that all required fieldnames are provided.
internal::ValidateMxStructField(variables, "Name", mxCHAR_CLASS, true);
internal::ValidateMxStructField(variables, "Type", mxCHAR_CLASS, false);
internal::ValidateMxStructField(variables, "Data", mxUNKNOWN_CLASS, true);
internal::ValidateMxStructField(variables, "Valid", mxLOGICAL_CLASS, true);
// Verify that all required fieldnames are provided.
internal::ValidateMxStructField(metadata, "NumRows", mxDOUBLE_CLASS, false);
internal::ValidateMxStructField(metadata, "NumVariables", mxDOUBLE_CLASS, false);
// Get the number of columns in the struct array.
size_t num_columns = internal::GetNumberOfElements(variables);
// Get the NumRows field in the struct array and set on TableWriter.
num_rows_ = static_cast<int64_t>(mxGetScalar(mxGetField(metadata, 0, "NumRows")));
// Get the total number of variables. This is checked later for consistency with
// the provided number of columns before finishing the file write.
num_variables_ =
static_cast<int64_t>(mxGetScalar(mxGetField(metadata, 0, "NumVariables")));
// Verify that we have all the columns required for writing
// Currently we need all columns to be passed in together in the WriteVariables method.
internal::ValidateNumColumns(static_cast<int64_t>(num_columns), num_variables_);
arrow::SchemaBuilder schema_builder;
std::vector<std::shared_ptr<arrow::Array>> table_columns;
const int64_t bitpacked_length = internal::BitPackedLength(num_rows_);
// Iterate over the input columns and generate arrow arrays.
for (int idx = 0; idx < num_columns; ++idx) {
// Unwrap constituent mxArray*s from the mxStructArray*. This is safe since we
// already checked for existence and non-nullness of these types.
const mxArray* name = mxGetField(variables, idx, "Name");
const mxArray* data = mxGetField(variables, idx, "Data");
const mxArray* type = mxGetField(variables, idx, "Type");
const mxArray* valid = mxGetField(variables, idx, "Valid");
// Convert column and type name to a std::string from mxArray*.
std::string name_str = internal::MxArrayToString(name);
std::string type_str = internal::MxArrayToString(type);
auto datatype = internal::ConvertMatlabTypeStringToArrowDataType(type_str);
auto field = std::make_shared<arrow::Field>(name_str, datatype);
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> validity_bitmap,
arrow::AllocateResizableBuffer(internal::BitPackedLength(num_rows_)));
// Populate bit-packed arrow::Buffer using validity data in the mxArray*.
internal::BitPackBuffer(valid, validity_bitmap);
// Wrap mxArray data in an arrow::Array of the equivalent type.
auto array =
internal::WriteVariableData(data, type_str, validity_bitmap);
// Verify that the arrow::Array has the right number of elements.
internal::ValidateNumRows(array->length(), num_rows_);
// Append the field to the schema builder
RETURN_NOT_OK(schema_builder.AddField(field));
// Store the table column
table_columns.push_back(std::move(array));
}
// Create the table schema
ARROW_ASSIGN_OR_RAISE(auto table_schema, schema_builder.Finish());
// Specify the feather file format version as V1
arrow::ipc::feather::WriteProperties write_props;
write_props.version = arrow::ipc::feather::kFeatherV1Version;
std::shared_ptr<arrow::Table> table = arrow::Table::Make(table_schema, table_columns);
// Write the Feather file metadata to the end of the file.
return ipc::feather::WriteTable(*table, file_output_stream_.get(), write_props);
}
} // namespace matlab
} // namespace arrow