| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| //! Arrow IPC File and Stream Writers |
| //! |
| //! The `FileWriter` and `StreamWriter` have similar interfaces, |
| //! however the `FileWriter` additionally writes the IPC file footer, |
| //! which records block locations so readers can randomly access the file. |
| |
| use std::cmp::min; |
| use std::collections::HashMap; |
| use std::io::{BufWriter, Write}; |
| use std::sync::Arc; |
| |
| use flatbuffers::FlatBufferBuilder; |
| |
| use arrow_array::builder::BufferBuilder; |
| use arrow_array::cast::*; |
| use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}; |
| use arrow_array::*; |
| use arrow_buffer::bit_util; |
| use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; |
| use arrow_data::{layout, ArrayData, ArrayDataBuilder, BufferSpec}; |
| use arrow_schema::*; |
| |
| use crate::compression::CompressionCodec; |
| use crate::CONTINUATION_MARKER; |
| |
| /// IPC write options used to control the behaviour of the [`IpcDataGenerator`] |
| #[derive(Debug, Clone)] |
| pub struct IpcWriteOptions { |
| /// Write padding after memory buffers to this multiple of bytes. |
| /// Must be 8, 16, 32, or 64 - defaults to 64. |
| alignment: u8, |
| /// The legacy format is for releases before 0.15.0, and uses metadata V4 |
| write_legacy_ipc_format: bool, |
| /// The metadata version to write. The Rust IPC writer supports V4+ |
| /// |
| /// *Default versions per crate* |
| /// |
| /// When creating the default IpcWriteOptions, the following metadata versions are used: |
| /// |
| /// * version 2.0.0: V4, with legacy format enabled |
| /// * version 4.0.0: V5 |
| metadata_version: crate::MetadataVersion, |
| /// Compression, if desired. Will result in a runtime error |
| /// if the corresponding feature is not enabled |
| batch_compression_type: Option<crate::CompressionType>, |
| /// Flag indicating whether the writer should preserve the dictionary IDs defined in the |
| /// schema or generate unique dictionary IDs internally during encoding. |
| /// |
| /// Defaults to `true` |
| preserve_dict_id: bool, |
| } |
| |
| impl IpcWriteOptions { |
| /// Configures compression when writing IPC files. |
| /// |
| /// Will result in a runtime error if the corresponding feature |
| /// is not enabled |
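| /// |
| /// # Example |
| /// |
| /// A minimal sketch of enabling LZ4 compression; the corresponding `lz4` |
| /// feature must be enabled for compressed batches to actually be written: |
| /// |
| /// ``` |
| /// # use arrow_ipc::writer::IpcWriteOptions; |
| /// # use arrow_ipc::CompressionType; |
| /// let options = IpcWriteOptions::default() |
| ///     .try_with_compression(Some(CompressionType::LZ4_FRAME)) |
| ///     .unwrap(); |
| /// ``` |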
| pub fn try_with_compression( |
| mut self, |
| batch_compression_type: Option<crate::CompressionType>, |
| ) -> Result<Self, ArrowError> { |
| self.batch_compression_type = batch_compression_type; |
| |
| if self.batch_compression_type.is_some() |
| && self.metadata_version < crate::MetadataVersion::V5 |
| { |
| return Err(ArrowError::InvalidArgumentError( |
| "Compression only supported in metadata v5 and above".to_string(), |
| )); |
| } |
| Ok(self) |
| } |
| /// Try to create [`IpcWriteOptions`], checking for incompatible settings |
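| /// |
| /// # Example |
| /// |
| /// A small sketch using 8-byte alignment, the non-legacy format and metadata |
| /// version 5 (the combination also used by examples elsewhere in this crate): |
| /// |
| /// ``` |
| /// # use arrow_ipc::writer::IpcWriteOptions; |
| /// # use arrow_ipc::MetadataVersion; |
| /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); |
| /// ``` |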
| pub fn try_new( |
| alignment: usize, |
| write_legacy_ipc_format: bool, |
| metadata_version: crate::MetadataVersion, |
| ) -> Result<Self, ArrowError> { |
| let is_alignment_valid = |
| alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64; |
| if !is_alignment_valid { |
| return Err(ArrowError::InvalidArgumentError( |
| "Alignment should be 8, 16, 32, or 64.".to_string(), |
| )); |
| } |
| let alignment: u8 = u8::try_from(alignment).expect("range already checked"); |
| match metadata_version { |
| crate::MetadataVersion::V1 |
| | crate::MetadataVersion::V2 |
| | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError( |
| "Writing IPC metadata version 3 and lower not supported".to_string(), |
| )), |
| crate::MetadataVersion::V4 => Ok(Self { |
| alignment, |
| write_legacy_ipc_format, |
| metadata_version, |
| batch_compression_type: None, |
| preserve_dict_id: true, |
| }), |
| crate::MetadataVersion::V5 => { |
| if write_legacy_ipc_format { |
| Err(ArrowError::InvalidArgumentError( |
| "Legacy IPC format only supported on metadata version 4".to_string(), |
| )) |
| } else { |
| Ok(Self { |
| alignment, |
| write_legacy_ipc_format, |
| metadata_version, |
| batch_compression_type: None, |
| preserve_dict_id: true, |
| }) |
| } |
| } |
| z => Err(ArrowError::InvalidArgumentError(format!( |
| "Unsupported crate::MetadataVersion {z:?}" |
| ))), |
| } |
| } |
| |
| /// Return whether the writer is configured to preserve the dictionary IDs |
| /// defined in the schema (see [`Self::with_preserve_dict_id`]) |
| pub fn preserve_dict_id(&self) -> bool { |
| self.preserve_dict_id |
| } |
| |
| /// Set whether the IPC writer should preserve the dictionary IDs in the schema |
| /// or auto-assign unique dictionary IDs during encoding (defaults to true) |
| /// |
| /// If this option is true, the application must handle assigning ids |
| /// to the dictionary batches in order to encode them correctly |
| /// |
| /// The default will change to `false` in future releases |
| pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self { |
| self.preserve_dict_id = preserve_dict_id; |
| self |
| } |
| } |
| |
| impl Default for IpcWriteOptions { |
| fn default() -> Self { |
| Self { |
| alignment: 64, |
| write_legacy_ipc_format: false, |
| metadata_version: crate::MetadataVersion::V5, |
| batch_compression_type: None, |
| preserve_dict_id: true, |
| } |
| } |
| } |
| |
| #[derive(Debug, Default)] |
| /// Handles low level details of encoding [`Array`] and [`Schema`] into the |
| /// [Arrow IPC Format]. |
| /// |
| /// # Example: |
| /// ``` |
| /// # fn run() { |
| /// # use std::sync::Arc; |
| /// # use arrow_array::UInt64Array; |
| /// # use arrow_array::RecordBatch; |
| /// # use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; |
| /// |
| /// // Create a record batch |
| /// let batch = RecordBatch::try_from_iter(vec![ |
| /// ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _) |
| /// ]).unwrap(); |
| /// |
| /// // Error if dictionary ids are replaced. |
| /// let error_on_replacement = true; |
| /// let options = IpcWriteOptions::default(); |
| /// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement); |
| /// |
| /// // encode the batch into zero or more encoded dictionaries |
| /// // and the data for the actual array. |
| /// let data_gen = IpcDataGenerator::default(); |
| /// let (encoded_dictionaries, encoded_message) = data_gen |
| /// .encoded_batch(&batch, &mut dictionary_tracker, &options) |
| /// .unwrap(); |
| /// # } |
| /// ``` |
| /// |
| /// [Arrow IPC Format]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc |
| pub struct IpcDataGenerator {} |
| |
| impl IpcDataGenerator { |
| /// Converts a [`Schema`] to an IPC schema message, returned as [`EncodedData`] |
| /// with an empty body |
| pub fn schema_to_bytes(&self, schema: &Schema, write_options: &IpcWriteOptions) -> EncodedData { |
| let mut fbb = FlatBufferBuilder::new(); |
| let schema = { |
| let fb = crate::convert::schema_to_fb_offset(&mut fbb, schema); |
| fb.as_union_value() |
| }; |
| |
| let mut message = crate::MessageBuilder::new(&mut fbb); |
| message.add_version(write_options.metadata_version); |
| message.add_header_type(crate::MessageHeader::Schema); |
| message.add_bodyLength(0); |
| message.add_header(schema); |
| // TODO: custom metadata |
| let data = message.finish(); |
| fbb.finish(data, None); |
| |
| let data = fbb.finished_data(); |
| EncodedData { |
| ipc_message: data.to_vec(), |
| arrow_data: vec![], |
| } |
| } |
| |
| fn _encode_dictionaries<I: Iterator<Item = i64>>( |
| &self, |
| column: &ArrayRef, |
| encoded_dictionaries: &mut Vec<EncodedData>, |
| dictionary_tracker: &mut DictionaryTracker, |
| write_options: &IpcWriteOptions, |
| dict_id: &mut I, |
| ) -> Result<(), ArrowError> { |
| match column.data_type() { |
| DataType::Struct(fields) => { |
| let s = as_struct_array(column); |
| for (field, column) in fields.iter().zip(s.columns()) { |
| self.encode_dictionaries( |
| field, |
| column, |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| } |
| } |
| DataType::RunEndEncoded(_, values) => { |
| let data = column.to_data(); |
| if data.child_data().len() != 2 { |
| return Err(ArrowError::InvalidArgumentError(format!( |
| "The run encoded array should have exactly two child arrays. Found {}", |
| data.child_data().len() |
| ))); |
| } |
| // The run_ends array is not expected to be dictionary encoded. Hence encode dictionaries |
| // only for values array. |
| let values_array = make_array(data.child_data()[1].clone()); |
| self.encode_dictionaries( |
| values, |
| &values_array, |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| } |
| DataType::List(field) => { |
| let list = as_list_array(column); |
| self.encode_dictionaries( |
| field, |
| list.values(), |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| } |
| DataType::LargeList(field) => { |
| let list = as_large_list_array(column); |
| self.encode_dictionaries( |
| field, |
| list.values(), |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| } |
| DataType::FixedSizeList(field, _) => { |
| let list = column |
| .as_any() |
| .downcast_ref::<FixedSizeListArray>() |
| .expect("Unable to downcast to fixed size list array"); |
| self.encode_dictionaries( |
| field, |
| list.values(), |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| } |
| DataType::Map(field, _) => { |
| let map_array = as_map_array(column); |
| |
| let (keys, values) = match field.data_type() { |
| DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]), |
| _ => panic!("Incorrect field data type {:?}", field.data_type()), |
| }; |
| |
| // keys |
| self.encode_dictionaries( |
| keys, |
| map_array.keys(), |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| |
| // values |
| self.encode_dictionaries( |
| values, |
| map_array.values(), |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| } |
| DataType::Union(fields, _) => { |
| let union = as_union_array(column); |
| for (type_id, field) in fields.iter() { |
| let column = union.child(type_id); |
| self.encode_dictionaries( |
| field, |
| column, |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id, |
| )?; |
| } |
| } |
| _ => (), |
| } |
| |
| Ok(()) |
| } |
| |
| fn encode_dictionaries<I: Iterator<Item = i64>>( |
| &self, |
| field: &Field, |
| column: &ArrayRef, |
| encoded_dictionaries: &mut Vec<EncodedData>, |
| dictionary_tracker: &mut DictionaryTracker, |
| write_options: &IpcWriteOptions, |
| dict_id_seq: &mut I, |
| ) -> Result<(), ArrowError> { |
| match column.data_type() { |
| DataType::Dictionary(_key_type, _value_type) => { |
| let dict_id = dict_id_seq |
| .next() |
| .or_else(|| field.dict_id()) |
| .ok_or_else(|| { |
| ArrowError::IpcError(format!("no dict id for field {}", field.name())) |
| })?; |
| |
| let dict_data = column.to_data(); |
| let dict_values = &dict_data.child_data()[0]; |
| |
| let values = make_array(dict_data.child_data()[0].clone()); |
| |
| self._encode_dictionaries( |
| &values, |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id_seq, |
| )?; |
| |
| let emit = dictionary_tracker.insert(dict_id, column)?; |
| |
| if emit { |
| encoded_dictionaries.push(self.dictionary_batch_to_bytes( |
| dict_id, |
| dict_values, |
| write_options, |
| )?); |
| } |
| } |
| _ => self._encode_dictionaries( |
| column, |
| encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| dict_id_seq, |
| )?, |
| } |
| |
| Ok(()) |
| } |
| |
| /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch). |
| /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s (so they are only sent once) |
| /// Make sure the [DictionaryTracker] is initialized at the start of the stream. |
| pub fn encoded_batch( |
| &self, |
| batch: &RecordBatch, |
| dictionary_tracker: &mut DictionaryTracker, |
| write_options: &IpcWriteOptions, |
| ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> { |
| let schema = batch.schema(); |
| let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len()); |
| |
| let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter(); |
| |
| for (i, field) in schema.fields().iter().enumerate() { |
| let column = batch.column(i); |
| self.encode_dictionaries( |
| field, |
| column, |
| &mut encoded_dictionaries, |
| dictionary_tracker, |
| write_options, |
| &mut dict_id, |
| )?; |
| } |
| |
| let encoded_message = self.record_batch_to_bytes(batch, write_options)?; |
| Ok((encoded_dictionaries, encoded_message)) |
| } |
| |
| /// Write a `RecordBatch` into two sets of bytes, one for the header (crate::Message) and the |
| /// other for the batch's data |
| fn record_batch_to_bytes( |
| &self, |
| batch: &RecordBatch, |
| write_options: &IpcWriteOptions, |
| ) -> Result<EncodedData, ArrowError> { |
| let mut fbb = FlatBufferBuilder::new(); |
| |
| let mut nodes: Vec<crate::FieldNode> = vec![]; |
| let mut buffers: Vec<crate::Buffer> = vec![]; |
| let mut arrow_data: Vec<u8> = vec![]; |
| let mut offset = 0; |
| |
| // get the type of compression |
| let batch_compression_type = write_options.batch_compression_type; |
| |
| let compression = batch_compression_type.map(|batch_compression_type| { |
| let mut c = crate::BodyCompressionBuilder::new(&mut fbb); |
| c.add_method(crate::BodyCompressionMethod::BUFFER); |
| c.add_codec(batch_compression_type); |
| c.finish() |
| }); |
| |
| let compression_codec: Option<CompressionCodec> = |
| batch_compression_type.map(TryInto::try_into).transpose()?; |
| |
| let mut variadic_buffer_counts = vec![]; |
| |
| for array in batch.columns() { |
| let array_data = array.to_data(); |
| offset = write_array_data( |
| &array_data, |
| &mut buffers, |
| &mut arrow_data, |
| &mut nodes, |
| offset, |
| array.len(), |
| array.null_count(), |
| compression_codec, |
| write_options, |
| )?; |
| |
| append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data); |
| } |
| // pad the tail of body data |
| let len = arrow_data.len(); |
| let pad_len = pad_to_alignment(write_options.alignment, len); |
| arrow_data.extend_from_slice(&PADDING[..pad_len]); |
| |
| // write data |
| let buffers = fbb.create_vector(&buffers); |
| let nodes = fbb.create_vector(&nodes); |
| let variadic_buffer = if variadic_buffer_counts.is_empty() { |
| None |
| } else { |
| Some(fbb.create_vector(&variadic_buffer_counts)) |
| }; |
| |
| let root = { |
| let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb); |
| batch_builder.add_length(batch.num_rows() as i64); |
| batch_builder.add_nodes(nodes); |
| batch_builder.add_buffers(buffers); |
| if let Some(c) = compression { |
| batch_builder.add_compression(c); |
| } |
| |
| if let Some(v) = variadic_buffer { |
| batch_builder.add_variadicBufferCounts(v); |
| } |
| let b = batch_builder.finish(); |
| b.as_union_value() |
| }; |
| // create a crate::Message |
| let mut message = crate::MessageBuilder::new(&mut fbb); |
| message.add_version(write_options.metadata_version); |
| message.add_header_type(crate::MessageHeader::RecordBatch); |
| message.add_bodyLength(arrow_data.len() as i64); |
| message.add_header(root); |
| let root = message.finish(); |
| fbb.finish(root, None); |
| let finished_data = fbb.finished_data(); |
| |
| Ok(EncodedData { |
| ipc_message: finished_data.to_vec(), |
| arrow_data, |
| }) |
| } |
| |
| /// Write dictionary values into two sets of bytes, one for the header (crate::Message) and the |
| /// other for the data |
| fn dictionary_batch_to_bytes( |
| &self, |
| dict_id: i64, |
| array_data: &ArrayData, |
| write_options: &IpcWriteOptions, |
| ) -> Result<EncodedData, ArrowError> { |
| let mut fbb = FlatBufferBuilder::new(); |
| |
| let mut nodes: Vec<crate::FieldNode> = vec![]; |
| let mut buffers: Vec<crate::Buffer> = vec![]; |
| let mut arrow_data: Vec<u8> = vec![]; |
| |
| // get the type of compression |
| let batch_compression_type = write_options.batch_compression_type; |
| |
| let compression = batch_compression_type.map(|batch_compression_type| { |
| let mut c = crate::BodyCompressionBuilder::new(&mut fbb); |
| c.add_method(crate::BodyCompressionMethod::BUFFER); |
| c.add_codec(batch_compression_type); |
| c.finish() |
| }); |
| |
| let compression_codec: Option<CompressionCodec> = batch_compression_type |
| .map(|batch_compression_type| batch_compression_type.try_into()) |
| .transpose()?; |
| |
| write_array_data( |
| array_data, |
| &mut buffers, |
| &mut arrow_data, |
| &mut nodes, |
| 0, |
| array_data.len(), |
| array_data.null_count(), |
| compression_codec, |
| write_options, |
| )?; |
| |
| let mut variadic_buffer_counts = vec![]; |
| append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data); |
| |
| // pad the tail of body data |
| let len = arrow_data.len(); |
| let pad_len = pad_to_alignment(write_options.alignment, len); |
| arrow_data.extend_from_slice(&PADDING[..pad_len]); |
| |
| // write data |
| let buffers = fbb.create_vector(&buffers); |
| let nodes = fbb.create_vector(&nodes); |
| let variadic_buffer = if variadic_buffer_counts.is_empty() { |
| None |
| } else { |
| Some(fbb.create_vector(&variadic_buffer_counts)) |
| }; |
| |
| let root = { |
| let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb); |
| batch_builder.add_length(array_data.len() as i64); |
| batch_builder.add_nodes(nodes); |
| batch_builder.add_buffers(buffers); |
| if let Some(c) = compression { |
| batch_builder.add_compression(c); |
| } |
| if let Some(v) = variadic_buffer { |
| batch_builder.add_variadicBufferCounts(v); |
| } |
| batch_builder.finish() |
| }; |
| |
| let root = { |
| let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb); |
| batch_builder.add_id(dict_id); |
| batch_builder.add_data(root); |
| batch_builder.finish().as_union_value() |
| }; |
| |
| let root = { |
| let mut message_builder = crate::MessageBuilder::new(&mut fbb); |
| message_builder.add_version(write_options.metadata_version); |
| message_builder.add_header_type(crate::MessageHeader::DictionaryBatch); |
| message_builder.add_bodyLength(arrow_data.len() as i64); |
| message_builder.add_header(root); |
| message_builder.finish() |
| }; |
| |
| fbb.finish(root, None); |
| let finished_data = fbb.finished_data(); |
| |
| Ok(EncodedData { |
| ipc_message: finished_data.to_vec(), |
| arrow_data, |
| }) |
| } |
| } |
| |
| fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) { |
| match array.data_type() { |
| DataType::BinaryView | DataType::Utf8View => { |
| // The spec documents that the counts only include the variadic buffers, not the view/null buffers. |
| // https://arrow.apache.org/docs/format/Columnar.html#variadic-buffers |
| counts.push(array.buffers().len() as i64 - 1); |
| } |
| DataType::Dictionary(_, _) => { |
| // Do nothing |
| // Dictionary types are handled in `encode_dictionaries`. |
| } |
| _ => { |
| for child in array.child_data() { |
| append_variadic_buffer_counts(counts, child) |
| } |
| } |
| } |
| } |
| |
| pub(crate) fn unslice_run_array(arr: ArrayData) -> Result<ArrayData, ArrowError> { |
| match arr.data_type() { |
| DataType::RunEndEncoded(k, _) => match k.data_type() { |
| DataType::Int16 => { |
| Ok(into_zero_offset_run_array(RunArray::<Int16Type>::from(arr))?.into_data()) |
| } |
| DataType::Int32 => { |
| Ok(into_zero_offset_run_array(RunArray::<Int32Type>::from(arr))?.into_data()) |
| } |
| DataType::Int64 => { |
| Ok(into_zero_offset_run_array(RunArray::<Int64Type>::from(arr))?.into_data()) |
| } |
| d => unreachable!("Unexpected data type {d}"), |
| }, |
| d => Err(ArrowError::InvalidArgumentError(format!( |
| "The given array is not a run array. Data type of given array: {d}" |
| ))), |
| } |
| } |
| |
| // Returns a `RunArray` with zero offset and length matching the last value |
| // in the run_ends array. |
| fn into_zero_offset_run_array<R: RunEndIndexType>( |
| run_array: RunArray<R>, |
| ) -> Result<RunArray<R>, ArrowError> { |
| let run_ends = run_array.run_ends(); |
| if run_ends.offset() == 0 && run_ends.max_value() == run_ends.len() { |
| return Ok(run_array); |
| } |
| |
| // The physical index of the original run_ends array from which the `ArrayData` is sliced. |
| let start_physical_index = run_ends.get_start_physical_index(); |
| |
| // The physical index of the original run_ends array until which the `ArrayData` is sliced. |
| let end_physical_index = run_ends.get_end_physical_index(); |
| |
| let physical_length = end_physical_index - start_physical_index + 1; |
| |
| // build new run_ends array by subtracting offset from run ends. |
| let offset = R::Native::usize_as(run_ends.offset()); |
| let mut builder = BufferBuilder::<R::Native>::new(physical_length); |
| for run_end_value in &run_ends.values()[start_physical_index..end_physical_index] { |
| builder.append(run_end_value.sub_wrapping(offset)); |
| } |
| builder.append(R::Native::from_usize(run_array.len()).unwrap()); |
| let new_run_ends = unsafe { |
| // Safety: |
| // The function builds a valid run_ends array and hence need not be validated. |
| ArrayDataBuilder::new(R::DATA_TYPE) |
| .len(physical_length) |
| .add_buffer(builder.finish()) |
| .build_unchecked() |
| }; |
| |
| // build new values by slicing physical indices. |
| let new_values = run_array |
| .values() |
| .slice(start_physical_index, physical_length) |
| .into_data(); |
| |
| let builder = ArrayDataBuilder::new(run_array.data_type().clone()) |
| .len(run_array.len()) |
| .add_child_data(new_run_ends) |
| .add_child_data(new_values); |
| let array_data = unsafe { |
| // Safety: |
| // This function builds a valid run array and hence can skip validation. |
| builder.build_unchecked() |
| }; |
| Ok(array_data.into()) |
| } |
| |
| /// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary |
| /// multiple times. Can optionally error if an update to an existing dictionary is attempted, which |
| /// isn't allowed in the `FileWriter`. |
| pub struct DictionaryTracker { |
| written: HashMap<i64, ArrayData>, |
| dict_ids: Vec<i64>, |
| error_on_replacement: bool, |
| preserve_dict_id: bool, |
| } |
| |
| impl DictionaryTracker { |
| /// Create a new [`DictionaryTracker`]. |
| /// |
| /// If `error_on_replacement` is true, an error will be generated if an |
| /// update to an existing dictionary is attempted. |
| /// |
| /// The returned tracker preserves the dictionary IDs defined in the schema; |
| /// use [`Self::new_with_preserve_dict_id`] to instead assign unique dictionary |
| /// IDs during encoding. |
| pub fn new(error_on_replacement: bool) -> Self { |
| Self { |
| written: HashMap::new(), |
| dict_ids: Vec::new(), |
| error_on_replacement, |
| preserve_dict_id: true, |
| } |
| } |
| |
| /// Create a new [`DictionaryTracker`]. |
| /// |
| /// If `error_on_replacement` is true, an error will be generated if an |
| /// update to an existing dictionary is attempted. |
| /// |
| /// If `preserve_dict_id` is true, the dictionary ID defined in the schema |
| /// is used, otherwise a unique dictionary ID will be assigned by incrementing |
| /// the last seen dictionary ID (or using `0` if no other dictionary IDs have |
| /// been seen). |
| pub fn new_with_preserve_dict_id(error_on_replacement: bool, preserve_dict_id: bool) -> Self { |
| Self { |
| written: HashMap::new(), |
| dict_ids: Vec::new(), |
| error_on_replacement, |
| preserve_dict_id, |
| } |
| } |
| |
| /// Set the dictionary ID for `field`. |
| /// |
| /// If `preserve_dict_id` is true, this will return the `dict_id` in `field` (or panic if `field` does |
| /// not have a `dict_id` defined). |
| /// |
| /// If `preserve_dict_id` is false, this will return the value of the last assigned `dict_id` |
| /// incremented by 1, or `0` if no dictionary IDs have yet been assigned |
| pub fn set_dict_id(&mut self, field: &Field) -> i64 { |
| let next = if self.preserve_dict_id { |
| field.dict_id().expect("no dict_id in field") |
| } else { |
| self.dict_ids |
| .last() |
| .copied() |
| .map(|i| i + 1) |
| .unwrap_or_default() |
| }; |
| |
| self.dict_ids.push(next); |
| next |
| } |
| |
| /// Return the sequence of dictionary IDs in the order they should be observed while |
| /// traversing the schema |
| pub fn dict_id(&mut self) -> &[i64] { |
| &self.dict_ids |
| } |
| |
| /// Keep track of the dictionary with the given ID and values. Behavior: |
| /// |
| /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate |
| /// that the dictionary was not actually inserted (because it's already been seen). |
| /// * If this ID has been written already but with different data, and this tracker is |
| /// configured to return an error, return an error. |
| /// * If the tracker has not been configured to error on replacement or this dictionary |
| /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just |
| /// inserted. |
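| /// |
| /// # Example |
| /// |
| /// A minimal sketch using an arbitrary dictionary id of `0`; in practice ids |
| /// come from the schema or are assigned by the writer: |
| /// |
| /// ``` |
| /// # use std::sync::Arc; |
| /// # use arrow_array::{ArrayRef, DictionaryArray, types::Int32Type}; |
| /// # use arrow_ipc::writer::DictionaryTracker; |
| /// let dict: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect(); |
| /// let column: ArrayRef = Arc::new(dict); |
| /// |
| /// let mut tracker = DictionaryTracker::new(true); |
| /// // First time this id is seen: the dictionary should be emitted |
| /// assert!(tracker.insert(0, &column).unwrap()); |
| /// // Same id and same values: nothing new to emit |
| /// assert!(!tracker.insert(0, &column).unwrap()); |
| /// ``` |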
| pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result<bool, ArrowError> { |
| let dict_data = column.to_data(); |
| let dict_values = &dict_data.child_data()[0]; |
| |
| // If a dictionary with this id was already emitted, check if it was the same. |
| if let Some(last) = self.written.get(&dict_id) { |
| if ArrayData::ptr_eq(&last.child_data()[0], dict_values) { |
| // Same dictionary values => no need to emit it again |
| return Ok(false); |
| } |
| if self.error_on_replacement { |
| // If error on replacement perform a logical comparison |
| if last.child_data()[0] == *dict_values { |
| // Same dictionary values => no need to emit it again |
| return Ok(false); |
| } |
| return Err(ArrowError::InvalidArgumentError( |
| "Dictionary replacement detected when writing IPC file format. \ |
| Arrow IPC files only support a single dictionary for a given field \ |
| across all batches." |
| .to_string(), |
| )); |
| } |
| } |
| |
| self.written.insert(dict_id, dict_data); |
| Ok(true) |
| } |
| } |
| |
| /// Writer for an IPC file |
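| /// |
| /// # Example |
| /// |
| /// A minimal sketch of writing a [`RecordBatch`] to an in-memory buffer |
| /// (a real application would typically write to a file instead): |
| /// |
| /// ``` |
| /// # use std::sync::Arc; |
| /// # use arrow_array::{Int32Array, RecordBatch}; |
| /// # use arrow_ipc::writer::FileWriter; |
| /// let batch = RecordBatch::try_from_iter(vec![ |
| ///     ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _) |
| /// ]).unwrap(); |
| /// |
| /// let buffer: Vec<u8> = Vec::new(); |
| /// let mut writer = FileWriter::try_new(buffer, batch.schema_ref()).unwrap(); |
| /// writer.write(&batch).unwrap(); |
| /// writer.finish().unwrap(); |
| /// |
| /// // Recover the buffer containing the encoded Arrow IPC file |
| /// let bytes = writer.into_inner().unwrap(); |
| /// assert!(!bytes.is_empty()); |
| /// ``` |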
| pub struct FileWriter<W: Write> { |
| /// The object to write to |
| writer: BufWriter<W>, |
| /// IPC write options |
| write_options: IpcWriteOptions, |
| /// A reference to the schema, used in validating record batches |
| schema: SchemaRef, |
| /// The running byte offset into the file, used to record block locations for random access |
| block_offsets: usize, |
| /// Dictionary blocks that will be written as part of the IPC footer |
| dictionary_blocks: Vec<crate::Block>, |
| /// Record blocks that will be written as part of the IPC footer |
| record_blocks: Vec<crate::Block>, |
| /// Whether the file footer has been written, and the writer is finished |
| finished: bool, |
| /// Keeps track of dictionaries that have been written |
| dictionary_tracker: DictionaryTracker, |
| /// User level customized metadata |
| custom_metadata: HashMap<String, String>, |
| |
| data_gen: IpcDataGenerator, |
| } |
| |
| impl<W: Write> FileWriter<W> { |
| /// Try to create a new writer, with the schema written as part of the header |
| pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> { |
| let write_options = IpcWriteOptions::default(); |
| Self::try_new_with_options(writer, schema, write_options) |
| } |
| |
| /// Try to create a new writer with IpcWriteOptions |
| pub fn try_new_with_options( |
| writer: W, |
| schema: &Schema, |
| write_options: IpcWriteOptions, |
| ) -> Result<Self, ArrowError> { |
| let data_gen = IpcDataGenerator::default(); |
| let mut writer = BufWriter::new(writer); |
| // write magic to header aligned on alignment boundary |
| let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len()); |
| let header_size = super::ARROW_MAGIC.len() + pad_len; |
| writer.write_all(&super::ARROW_MAGIC)?; |
| writer.write_all(&PADDING[..pad_len])?; |
| // write the schema, set the written bytes to the schema + header |
| let encoded_message = data_gen.schema_to_bytes(schema, &write_options); |
| let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?; |
| let preserve_dict_id = write_options.preserve_dict_id; |
| Ok(Self { |
| writer, |
| write_options, |
| schema: Arc::new(schema.clone()), |
| block_offsets: meta + data + header_size, |
| dictionary_blocks: vec![], |
| record_blocks: vec![], |
| finished: false, |
| dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( |
| true, |
| preserve_dict_id, |
| ), |
| custom_metadata: HashMap::new(), |
| data_gen, |
| }) |
| } |
| |
| /// Add a key/value pair to the file's custom metadata, which is written to the |
| /// footer when the writer is finished |
| pub fn write_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) { |
| self.custom_metadata.insert(key.into(), value.into()); |
| } |
| |
| /// Write a record batch to the file |
| pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { |
| if self.finished { |
| return Err(ArrowError::IpcError( |
| "Cannot write record batch to file writer as it is closed".to_string(), |
| )); |
| } |
| |
| let (encoded_dictionaries, encoded_message) = self.data_gen.encoded_batch( |
| batch, |
| &mut self.dictionary_tracker, |
| &self.write_options, |
| )?; |
| |
| for encoded_dictionary in encoded_dictionaries { |
| let (meta, data) = |
| write_message(&mut self.writer, encoded_dictionary, &self.write_options)?; |
| |
| let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64); |
| self.dictionary_blocks.push(block); |
| self.block_offsets += meta + data; |
| } |
| |
| let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?; |
| // add a record block for the footer |
| let block = crate::Block::new( |
| self.block_offsets as i64, |
| meta as i32, // TODO: is this still applicable? |
| data as i64, |
| ); |
| self.record_blocks.push(block); |
| self.block_offsets += meta + data; |
| Ok(()) |
| } |
| |
| /// Write footer and closing tag, then mark the writer as done |
| pub fn finish(&mut self) -> Result<(), ArrowError> { |
| if self.finished { |
| return Err(ArrowError::IpcError( |
| "Cannot write footer to file writer as it is closed".to_string(), |
| )); |
| } |
| |
| // write EOS |
| write_continuation(&mut self.writer, &self.write_options, 0)?; |
| |
| let mut fbb = FlatBufferBuilder::new(); |
| let dictionaries = fbb.create_vector(&self.dictionary_blocks); |
| let record_batches = fbb.create_vector(&self.record_blocks); |
| let schema = crate::convert::schema_to_fb_offset(&mut fbb, &self.schema); |
| let fb_custom_metadata = (!self.custom_metadata.is_empty()) |
| .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata)); |
| |
| let root = { |
| let mut footer_builder = crate::FooterBuilder::new(&mut fbb); |
| footer_builder.add_version(self.write_options.metadata_version); |
| footer_builder.add_schema(schema); |
| footer_builder.add_dictionaries(dictionaries); |
| footer_builder.add_recordBatches(record_batches); |
| if let Some(fb_custom_metadata) = fb_custom_metadata { |
| footer_builder.add_custom_metadata(fb_custom_metadata); |
| } |
| footer_builder.finish() |
| }; |
| fbb.finish(root, None); |
| let footer_data = fbb.finished_data(); |
| self.writer.write_all(footer_data)?; |
| self.writer |
| .write_all(&(footer_data.len() as i32).to_le_bytes())?; |
| self.writer.write_all(&super::ARROW_MAGIC)?; |
| self.writer.flush()?; |
| self.finished = true; |
| |
| Ok(()) |
| } |
| |
| /// Returns the arrow [`SchemaRef`] for this arrow file. |
| pub fn schema(&self) -> &SchemaRef { |
| &self.schema |
| } |
| |
| /// Gets a reference to the underlying writer. |
| pub fn get_ref(&self) -> &W { |
| self.writer.get_ref() |
| } |
| |
| /// Gets a mutable reference to the underlying writer. |
| /// |
| /// It is inadvisable to directly write to the underlying writer. |
| pub fn get_mut(&mut self) -> &mut W { |
| self.writer.get_mut() |
| } |
| |
| /// Unwraps the BufWriter housed in FileWriter.writer, returning the underlying |
| /// writer |
| /// |
| /// The buffer is flushed and the FileWriter is finished before returning the |
| /// writer. |
| pub fn into_inner(mut self) -> Result<W, ArrowError> { |
| if !self.finished { |
| self.finish()?; |
| } |
| self.writer.into_inner().map_err(ArrowError::from) |
| } |
| } |
| |
| impl<W: Write> RecordBatchWriter for FileWriter<W> { |
| fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { |
| self.write(batch) |
| } |
| |
| fn close(mut self) -> Result<(), ArrowError> { |
| self.finish() |
| } |
| } |
| |
| /// Writer for an IPC stream |
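| /// |
| /// # Example |
| /// |
| /// A minimal sketch of writing a [`RecordBatch`] to the streaming format, |
| /// collecting the output into an in-memory buffer: |
| /// |
| /// ``` |
| /// # use std::sync::Arc; |
| /// # use arrow_array::{Int32Array, RecordBatch}; |
| /// # use arrow_ipc::writer::StreamWriter; |
| /// let batch = RecordBatch::try_from_iter(vec![ |
| ///     ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _) |
| /// ]).unwrap(); |
| /// |
| /// let buffer: Vec<u8> = Vec::new(); |
| /// let mut writer = StreamWriter::try_new(buffer, batch.schema_ref()).unwrap(); |
| /// writer.write(&batch).unwrap(); |
| /// writer.finish().unwrap(); |
| /// let bytes = writer.into_inner().unwrap(); |
| /// assert!(!bytes.is_empty()); |
| /// ``` |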
| pub struct StreamWriter<W: Write> { |
| /// The object to write to |
| writer: BufWriter<W>, |
| /// IPC write options |
| write_options: IpcWriteOptions, |
| /// Whether the end-of-stream marker has been written, and the writer is finished |
| finished: bool, |
| /// Keeps track of dictionaries that have been written |
| dictionary_tracker: DictionaryTracker, |
| |
| data_gen: IpcDataGenerator, |
| } |
| |
| impl<W: Write> StreamWriter<W> { |
| /// Try to create a new writer, with the schema written as part of the header |
| pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> { |
| let write_options = IpcWriteOptions::default(); |
| Self::try_new_with_options(writer, schema, write_options) |
| } |
| |
| /// Try to create a new writer with the given [`IpcWriteOptions`] |
| pub fn try_new_with_options( |
| writer: W, |
| schema: &Schema, |
| write_options: IpcWriteOptions, |
| ) -> Result<Self, ArrowError> { |
| let data_gen = IpcDataGenerator::default(); |
| let mut writer = BufWriter::new(writer); |
| // write the schema, set the written bytes to the schema |
| let encoded_message = data_gen.schema_to_bytes(schema, &write_options); |
| write_message(&mut writer, encoded_message, &write_options)?; |
| let preserve_dict_id = write_options.preserve_dict_id; |
| Ok(Self { |
| writer, |
| write_options, |
| finished: false, |
| dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( |
| false, |
| preserve_dict_id, |
| ), |
| data_gen, |
| }) |
| } |
| |
| /// Write a record batch to the stream |
| pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { |
| if self.finished { |
| return Err(ArrowError::IpcError( |
| "Cannot write record batch to stream writer as it is closed".to_string(), |
| )); |
| } |
| |
| let (encoded_dictionaries, encoded_message) = self |
| .data_gen |
| .encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options) |
| .expect("StreamWriter is configured to not error on dictionary replacement"); |
| |
| for encoded_dictionary in encoded_dictionaries { |
| write_message(&mut self.writer, encoded_dictionary, &self.write_options)?; |
| } |
| |
| write_message(&mut self.writer, encoded_message, &self.write_options)?; |
| Ok(()) |
| } |
| |
| /// Write continuation bytes, and mark the stream as done |
| pub fn finish(&mut self) -> Result<(), ArrowError> { |
| if self.finished { |
| return Err(ArrowError::IpcError( |
| "Cannot write footer to stream writer as it is closed".to_string(), |
| )); |
| } |
| |
| write_continuation(&mut self.writer, &self.write_options, 0)?; |
| |
| self.finished = true; |
| |
| Ok(()) |
| } |
| |
| /// Gets a reference to the underlying writer. |
| pub fn get_ref(&self) -> &W { |
| self.writer.get_ref() |
| } |
| |
| /// Gets a mutable reference to the underlying writer. |
| /// |
| /// It is inadvisable to directly write to the underlying writer. |
| pub fn get_mut(&mut self) -> &mut W { |
| self.writer.get_mut() |
| } |
| |
| /// Unwraps the BufWriter housed in StreamWriter.writer, returning the underlying |
| /// writer |
| /// |
| /// The buffer is flushed and the StreamWriter is finished before returning the |
| /// writer. |
| /// |
| /// # Errors |
| /// |
| /// An [`Err`] may be returned if an error occurs while finishing the StreamWriter |
| /// or while flushing the buffer. |
| /// |
| /// # Example |
| /// |
| /// ``` |
| /// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions}; |
| /// # use arrow_ipc::MetadataVersion; |
| /// # use arrow_schema::{ArrowError, Schema}; |
| /// # fn main() -> Result<(), ArrowError> { |
| /// // The result we expect from an empty schema |
| /// let expected = vec![ |
| /// 255, 255, 255, 255, 48, 0, 0, 0, |
| /// 16, 0, 0, 0, 0, 0, 10, 0, |
| /// 12, 0, 10, 0, 9, 0, 4, 0, |
| /// 10, 0, 0, 0, 16, 0, 0, 0, |
| /// 0, 1, 4, 0, 8, 0, 8, 0, |
| /// 0, 0, 4, 0, 8, 0, 0, 0, |
| /// 4, 0, 0, 0, 0, 0, 0, 0, |
| /// 255, 255, 255, 255, 0, 0, 0, 0 |
| /// ]; |
| /// |
| /// let schema = Schema::empty(); |
| /// let buffer: Vec<u8> = Vec::new(); |
| /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?; |
| /// let stream_writer = StreamWriter::try_new_with_options(buffer, &schema, options)?; |
| /// |
| /// assert_eq!(stream_writer.into_inner()?, expected); |
| /// # Ok(()) |
| /// # } |
| /// ``` |
| pub fn into_inner(mut self) -> Result<W, ArrowError> { |
| if !self.finished { |
| self.finish()?; |
| } |
| self.writer.into_inner().map_err(ArrowError::from) |
| } |
| } |
| |
| impl<W: Write> RecordBatchWriter for StreamWriter<W> { |
| fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { |
| self.write(batch) |
| } |
| |
| fn close(mut self) -> Result<(), ArrowError> { |
| self.finish() |
| } |
| } |
| |
| /// Stores the encoded data, which is a `crate::Message`, and optional Arrow data |
| pub struct EncodedData { |
| /// An encoded `crate::Message` |
| pub ipc_message: Vec<u8>, |
| /// Arrow buffers to be written, should be an empty vec for schema messages |
| pub arrow_data: Vec<u8>, |
| } |
| /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written |
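| /// |
| /// # Example |
| /// |
| /// A minimal sketch that writes an encoded schema message (which has no body) |
| /// to an in-memory buffer: |
| /// |
| /// ``` |
| /// # use arrow_schema::{DataType, Field, Schema}; |
| /// # use arrow_ipc::writer::{write_message, IpcDataGenerator, IpcWriteOptions}; |
| /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); |
| /// let options = IpcWriteOptions::default(); |
| /// let encoded = IpcDataGenerator::default().schema_to_bytes(&schema, &options); |
| /// |
| /// let mut buffer: Vec<u8> = Vec::new(); |
| /// let (meta_len, body_len) = write_message(&mut buffer, encoded, &options).unwrap(); |
| /// assert_eq!(meta_len, buffer.len()); |
| /// assert_eq!(body_len, 0); // schema messages carry no Arrow body data |
| /// ``` |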
| pub fn write_message<W: Write>( |
| mut writer: W, |
| encoded: EncodedData, |
| write_options: &IpcWriteOptions, |
| ) -> Result<(usize, usize), ArrowError> { |
| let arrow_data_len = encoded.arrow_data.len(); |
| if arrow_data_len % usize::from(write_options.alignment) != 0 { |
| return Err(ArrowError::MemoryError( |
| "Arrow data not aligned".to_string(), |
| )); |
| } |
| |
| let a = usize::from(write_options.alignment - 1); |
| let buffer = encoded.ipc_message; |
| let flatbuf_size = buffer.len(); |
| let prefix_size = if write_options.write_legacy_ipc_format { |
| 4 |
| } else { |
| 8 |
| }; |
| let aligned_size = (flatbuf_size + prefix_size + a) & !a; |
| let padding_bytes = aligned_size - flatbuf_size - prefix_size; |
| |
| write_continuation( |
| &mut writer, |
| write_options, |
| (aligned_size - prefix_size) as i32, |
| )?; |
| |
| // write the flatbuf |
| if flatbuf_size > 0 { |
| writer.write_all(&buffer)?; |
| } |
| // write padding |
| writer.write_all(&PADDING[..padding_bytes])?; |
| |
| // write arrow data |
| let body_len = if arrow_data_len > 0 { |
| write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)? |
| } else { |
| 0 |
| }; |
| |
| Ok((aligned_size, body_len)) |
| } |
| |
| fn write_body_buffers<W: Write>( |
| mut writer: W, |
| data: &[u8], |
| alignment: u8, |
| ) -> Result<usize, ArrowError> { |
| let len = data.len(); |
| let pad_len = pad_to_alignment(alignment, len); |
| let total_len = len + pad_len; |
| |
| // write body buffer |
| writer.write_all(data)?; |
| if pad_len > 0 { |
| writer.write_all(&PADDING[..pad_len])?; |
| } |
| |
| writer.flush()?; |
| Ok(total_len) |
| } |
| |
| /// Write the message length prefix (preceded by the continuation marker, unless |
| /// the legacy IPC format is used) for a message of `total_len` bytes |
| fn write_continuation<W: Write>( |
| mut writer: W, |
| write_options: &IpcWriteOptions, |
| total_len: i32, |
| ) -> Result<usize, ArrowError> { |
| let mut written = 8; |
| |
| // the version of the writer determines whether continuation markers should be added |
| match write_options.metadata_version { |
| crate::MetadataVersion::V1 | crate::MetadataVersion::V2 | crate::MetadataVersion::V3 => { |
| unreachable!("Options with the metadata version cannot be created") |
| } |
| crate::MetadataVersion::V4 => { |
| if !write_options.write_legacy_ipc_format { |
| // v0.15.0 format |
| writer.write_all(&CONTINUATION_MARKER)?; |
| written = 4; |
| } |
| writer.write_all(&total_len.to_le_bytes()[..])?; |
| } |
| crate::MetadataVersion::V5 => { |
| // write continuation marker and message length |
| writer.write_all(&CONTINUATION_MARKER)?; |
| writer.write_all(&total_len.to_le_bytes()[..])?; |
| } |
| z => panic!("Unsupported crate::MetadataVersion {z:?}"), |
| }; |
| |
| writer.flush()?; |
| |
| Ok(written) |
| } |
| |
| /// Returns whether arrays of the given data type have a validity bitmap: |
| /// |
| /// * In V4, only null types have no validity bitmap |
| /// * In V5 and later, null, union and run-end encoded types have no validity bitmap |
| fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool { |
| if write_options.metadata_version < crate::MetadataVersion::V5 { |
| !matches!(data_type, DataType::Null) |
| } else { |
| !matches!( |
| data_type, |
| DataType::Null | DataType::Union(_, _) | DataType::RunEndEncoded(_, _) |
| ) |
| } |
| } |
| |
| /// Whether to truncate the buffer |
| #[inline] |
| fn buffer_need_truncate( |
| array_offset: usize, |
| buffer: &Buffer, |
| spec: &BufferSpec, |
| min_length: usize, |
| ) -> bool { |
| spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len()) |
| } |
| |
| /// Returns byte width for a buffer spec. Only for `BufferSpec::FixedWidth`. |
| #[inline] |
| fn get_buffer_element_width(spec: &BufferSpec) -> usize { |
| match spec { |
| BufferSpec::FixedWidth { byte_width, .. } => *byte_width, |
| _ => 0, |
| } |
| } |
| |
| /// Common functionality for re-encoding offsets. Returns the new offsets as well as |
| /// original start offset and length for use in slicing child data. |
| fn reencode_offsets<O: OffsetSizeTrait>( |
| offsets: &Buffer, |
| data: &ArrayData, |
| ) -> (Buffer, usize, usize) { |
| let offsets_slice: &[O] = offsets.typed_data::<O>(); |
| let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1]; |
| |
| let start_offset = offset_slice.first().unwrap(); |
| let end_offset = offset_slice.last().unwrap(); |
| |
| let offsets = match start_offset.as_usize() { |
| 0 => offsets.clone(), |
| _ => offset_slice.iter().map(|x| *x - *start_offset).collect(), |
| }; |
| |
| let start_offset = start_offset.as_usize(); |
| let end_offset = end_offset.as_usize(); |
| |
| (offsets, start_offset, end_offset - start_offset) |
| } |
| |
| /// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O` |
| /// |
| /// In particular, this handles re-encoding the offsets if they don't start at `0`, |
| /// slicing the values buffer as appropriate. This helps reduce the encoded |
| /// size of sliced arrays, as values that have been sliced away are not encoded |
| fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buffer) { |
| if data.is_empty() { |
| return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into()); |
| } |
| |
| let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data); |
| let values = data.buffers()[1].slice_with_length(original_start_offset, len); |
| (offsets, values) |
| } |
| |
| /// Similar logic as [`get_byte_array_buffers()`] but slices the child array instead |
| /// of a values buffer. |
| fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, ArrayData) { |
| if data.is_empty() { |
| return ( |
| MutableBuffer::new(0).into(), |
| data.child_data()[0].slice(0, 0), |
| ); |
| } |
| |
| let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data); |
| let child_data = data.child_data()[0].slice(original_start_offset, len); |
| (offsets, child_data) |
| } |
| |
| /// Write array data to a vector of bytes |
| #[allow(clippy::too_many_arguments)] |
| fn write_array_data( |
| array_data: &ArrayData, |
| buffers: &mut Vec<crate::Buffer>, |
| arrow_data: &mut Vec<u8>, |
| nodes: &mut Vec<crate::FieldNode>, |
| offset: i64, |
| num_rows: usize, |
| null_count: usize, |
| compression_codec: Option<CompressionCodec>, |
| write_options: &IpcWriteOptions, |
| ) -> Result<i64, ArrowError> { |
| let mut offset = offset; |
| if !matches!(array_data.data_type(), DataType::Null) { |
| nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64)); |
| } else { |
| // A NullArray's null_count equals its len, but the `null_count` passed in comes from ArrayData, |
| // where null_count is always 0. |
| nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64)); |
| } |
| if has_validity_bitmap(array_data.data_type(), write_options) { |
| // write null buffer if exists |
| let null_buffer = match array_data.nulls() { |
| None => { |
| // create a buffer and fill it with valid bits |
| let num_bytes = bit_util::ceil(num_rows, 8); |
| let buffer = MutableBuffer::new(num_bytes); |
| let buffer = buffer.with_bitset(num_bytes, true); |
| buffer.into() |
| } |
| Some(buffer) => buffer.inner().sliced(), |
| }; |
| |
| offset = write_buffer( |
| null_buffer.as_slice(), |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| } |
| |
| let data_type = array_data.data_type(); |
| if matches!(data_type, DataType::Binary | DataType::Utf8) { |
| let (offsets, values) = get_byte_array_buffers::<i32>(array_data); |
| for buffer in [offsets, values] { |
| offset = write_buffer( |
| buffer.as_slice(), |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| } |
| } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) { |
| // Slicing the views buffer is safe and easy, |
| // but pruning unneeded data buffers is much more nuanced since it's complicated to prove that no views reference the pruned buffers |
| // |
| // The current implementation just serializes the raw arrays as given and does not try to optimize anything. |
| // If users want to "compact" the arrays prior to sending them over IPC, |
| // they should consider the gc API suggested in #5513 |
| for buffer in array_data.buffers() { |
| offset = write_buffer( |
| buffer.as_slice(), |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| } |
| } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) { |
| let (offsets, values) = get_byte_array_buffers::<i64>(array_data); |
| for buffer in [offsets, values] { |
| offset = write_buffer( |
| buffer.as_slice(), |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| } |
| } else if DataType::is_numeric(data_type) |
| || DataType::is_temporal(data_type) |
| || matches!( |
| array_data.data_type(), |
| DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _) |
| ) |
| { |
| // Truncate values |
| assert_eq!(array_data.buffers().len(), 1); |
| |
| let buffer = &array_data.buffers()[0]; |
| let layout = layout(data_type); |
| let spec = &layout.buffers[0]; |
| |
| let byte_width = get_buffer_element_width(spec); |
| let min_length = array_data.len() * byte_width; |
| let buffer_slice = if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) { |
| let byte_offset = array_data.offset() * byte_width; |
| let buffer_length = min(min_length, buffer.len() - byte_offset); |
| &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)] |
| } else { |
| buffer.as_slice() |
| }; |
| offset = write_buffer( |
| buffer_slice, |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| } else if matches!(data_type, DataType::Boolean) { |
| // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes). |
| // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around. |
| assert_eq!(array_data.buffers().len(), 1); |
| |
| let buffer = &array_data.buffers()[0]; |
| let buffer = buffer.bit_slice(array_data.offset(), array_data.len()); |
| offset = write_buffer( |
| &buffer, |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| } else if matches!( |
| data_type, |
| DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) |
| ) { |
| assert_eq!(array_data.buffers().len(), 1); |
| assert_eq!(array_data.child_data().len(), 1); |
| |
| // Truncate offsets and the child data to avoid writing unnecessary data |
| let (offsets, sliced_child_data) = match data_type { |
| DataType::List(_) => get_list_array_buffers::<i32>(array_data), |
| DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data), |
| DataType::LargeList(_) => get_list_array_buffers::<i64>(array_data), |
| _ => unreachable!(), |
| }; |
| offset = write_buffer( |
| offsets.as_slice(), |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| offset = write_array_data( |
| &sliced_child_data, |
| buffers, |
| arrow_data, |
| nodes, |
| offset, |
| sliced_child_data.len(), |
| sliced_child_data.null_count(), |
| compression_codec, |
| write_options, |
| )?; |
| return Ok(offset); |
| } else { |
| for buffer in array_data.buffers() { |
| offset = write_buffer( |
| buffer, |
| buffers, |
| arrow_data, |
| offset, |
| compression_codec, |
| write_options.alignment, |
| )?; |
| } |
| } |
| |
| match array_data.data_type() { |
| DataType::Dictionary(_, _) => {} |
| DataType::RunEndEncoded(_, _) => { |
| // unslice the run encoded array. |
| let arr = unslice_run_array(array_data.clone())?; |
| // recursively write out nested structures |
| for data_ref in arr.child_data() { |
| // write the nested data (e.g list data) |
| offset = write_array_data( |
| data_ref, |
| buffers, |
| arrow_data, |
| nodes, |
| offset, |
| data_ref.len(), |
| data_ref.null_count(), |
| compression_codec, |
| write_options, |
| )?; |
| } |
| } |
| _ => { |
| // recursively write out nested structures |
| for data_ref in array_data.child_data() { |
| // write the nested data (e.g list data) |
| offset = write_array_data( |
| data_ref, |
| buffers, |
| arrow_data, |
| nodes, |
| offset, |
| data_ref.len(), |
| data_ref.null_count(), |
| compression_codec, |
| write_options, |
| )?; |
| } |
| } |
| } |
| Ok(offset) |
| } |
| |
| /// Writes a buffer into `arrow_data`, a vector of bytes, and adds its |
| /// [`crate::Buffer`] to `buffers`. Returns the new offset in `arrow_data`. |
| /// |
| /// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58> |
| /// Each constituent buffer is first compressed with the indicated |
| /// compressor, and then written with the uncompressed length in the first 8 |
| /// bytes as a 64-bit little-endian signed integer followed by the compressed |
| /// buffer bytes (and then padding as required by the protocol). The |
| /// uncompressed length may be set to -1 to indicate that the data that |
| /// follows is not compressed, which can be useful for cases where |
| /// compression does not yield appreciable savings. |
| fn write_buffer( |
| buffer: &[u8], // input |
| buffers: &mut Vec<crate::Buffer>, // output buffer descriptors |
| arrow_data: &mut Vec<u8>, // output stream |
| offset: i64, // current output stream offset |
| compression_codec: Option<CompressionCodec>, |
| alignment: u8, |
| ) -> Result<i64, ArrowError> { |
| let len: i64 = match compression_codec { |
| Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?, |
| None => { |
| arrow_data.extend_from_slice(buffer); |
| buffer.len() |
| } |
| } |
| .try_into() |
| .map_err(|e| { |
| ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}")) |
| })?; |
| |
| // make new index entry |
| buffers.push(crate::Buffer::new(offset, len)); |
| // padding and make offset aligned |
| let pad_len = pad_to_alignment(alignment, len as usize); |
| arrow_data.extend_from_slice(&PADDING[..pad_len]); |
| |
| Ok(offset + len + (pad_len as i64)) |
| } |
| |
| const PADDING: [u8; 64] = [0; 64]; |
| |
| /// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary |
| #[inline] |
| fn pad_to_alignment(alignment: u8, len: usize) -> usize { |
| let a = usize::from(alignment - 1); |
| ((len + a) & !a) - len |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use std::io::Cursor; |
| use std::io::Seek; |
| |
| use arrow_array::builder::GenericListBuilder; |
| use arrow_array::builder::MapBuilder; |
| use arrow_array::builder::UnionBuilder; |
| use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder}; |
| use arrow_array::types::*; |
| use arrow_buffer::ScalarBuffer; |
| |
| use crate::convert::fb_to_schema; |
| use crate::reader::*; |
| use crate::root_as_footer; |
| use crate::MetadataVersion; |
| |
| use super::*; |
| |
| fn serialize_file(rb: &RecordBatch) -> Vec<u8> { |
| let mut writer = FileWriter::try_new(vec![], rb.schema_ref()).unwrap(); |
| writer.write(rb).unwrap(); |
| writer.finish().unwrap(); |
| writer.into_inner().unwrap() |
| } |
| |
| fn deserialize_file(bytes: Vec<u8>) -> RecordBatch { |
| let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap(); |
| reader.next().unwrap().unwrap() |
| } |
| |
| fn serialize_stream(record: &RecordBatch) -> Vec<u8> { |
| // Use 8-byte alignment so that the various `truncate_*` tests can be compactly written, |
| // without needing to construct a giant array to spill over the 64-byte default alignment |
| // boundary. |
| const IPC_ALIGNMENT: usize = 8; |
| |
| let mut stream_writer = StreamWriter::try_new_with_options( |
| vec![], |
| record.schema_ref(), |
| IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), |
| ) |
| .unwrap(); |
| stream_writer.write(record).unwrap(); |
| stream_writer.finish().unwrap(); |
| stream_writer.into_inner().unwrap() |
| } |
| |
| fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch { |
| let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap(); |
| stream_reader.next().unwrap().unwrap() |
| } |
| |
| #[test] |
| #[cfg(feature = "lz4")] |
| fn test_write_empty_record_batch_lz4_compression() { |
| let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]); |
| let values: Vec<Option<i32>> = vec![]; |
| let array = Int32Array::from(values); |
| let record_batch = |
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap(); |
| |
| let mut file = tempfile::tempfile().unwrap(); |
| |
| { |
| let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) |
| .unwrap() |
| .try_with_compression(Some(crate::CompressionType::LZ4_FRAME)) |
| .unwrap(); |
| |
| let mut writer = |
| FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap(); |
| writer.write(&record_batch).unwrap(); |
| writer.finish().unwrap(); |
| } |
| file.rewind().unwrap(); |
| { |
| // read file |
| let reader = FileReader::try_new(file, None).unwrap(); |
| for read_batch in reader { |
| read_batch |
| .unwrap() |
| .columns() |
| .iter() |
| .zip(record_batch.columns()) |
| .for_each(|(a, b)| { |
| assert_eq!(a.data_type(), b.data_type()); |
| assert_eq!(a.len(), b.len()); |
| assert_eq!(a.null_count(), b.null_count()); |
| }); |
| } |
| } |
| } |
| |
| #[test] |
| #[cfg(feature = "lz4")] |
| fn test_write_file_with_lz4_compression() { |
| let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]); |
| let values: Vec<Option<i32>> = vec![Some(12), Some(1)]; |
| let array = Int32Array::from(values); |
| let record_batch = |
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap(); |
| |
| let mut file = tempfile::tempfile().unwrap(); |
| { |
| let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) |
| .unwrap() |
| .try_with_compression(Some(crate::CompressionType::LZ4_FRAME)) |
| .unwrap(); |
| |
| let mut writer = |
| FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap(); |
| writer.write(&record_batch).unwrap(); |
| writer.finish().unwrap(); |
| } |
| file.rewind().unwrap(); |
| { |
| // read file |
| let reader = FileReader::try_new(file, None).unwrap(); |
| for read_batch in reader { |
| read_batch |
| .unwrap() |
| .columns() |
| .iter() |
| .zip(record_batch.columns()) |
| .for_each(|(a, b)| { |
| assert_eq!(a.data_type(), b.data_type()); |
| assert_eq!(a.len(), b.len()); |
| assert_eq!(a.null_count(), b.null_count()); |
| }); |
| } |
| } |
| } |
| |
| #[test] |
| #[cfg(feature = "zstd")] |
| fn test_write_file_with_zstd_compression() { |
| let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]); |
| let values: Vec<Option<i32>> = vec![Some(12), Some(1)]; |
| let array = Int32Array::from(values); |
| let record_batch = |
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap(); |
| let mut file = tempfile::tempfile().unwrap(); |
| { |
| let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) |
| .unwrap() |
| .try_with_compression(Some(crate::CompressionType::ZSTD)) |
| .unwrap(); |
| |
| let mut writer = |
| FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap(); |
| writer.write(&record_batch).unwrap(); |
| writer.finish().unwrap(); |
| } |
| file.rewind().unwrap(); |
| { |
| // read file |
| let reader = FileReader::try_new(file, None).unwrap(); |
| for read_batch in reader { |
| read_batch |
| .unwrap() |
| .columns() |
| .iter() |
| .zip(record_batch.columns()) |
| .for_each(|(a, b)| { |
| assert_eq!(a.data_type(), b.data_type()); |
| assert_eq!(a.len(), b.len()); |
| assert_eq!(a.null_count(), b.null_count()); |
| }); |
| } |
| } |
| } |
| |
| #[test] |
| fn test_write_file() { |
| let schema = Schema::new(vec![Field::new("field1", DataType::UInt32, true)]); |
| let values: Vec<Option<u32>> = vec![ |
| Some(999), |
| None, |
| Some(235), |
| Some(123), |
| None, |
| None, |
| None, |
| None, |
| None, |
| ]; |
| let array1 = UInt32Array::from(values); |
| let batch = |
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array1) as ArrayRef]) |
| .unwrap(); |
| let mut file = tempfile::tempfile().unwrap(); |
| { |
| let mut writer = FileWriter::try_new(&mut file, &schema).unwrap(); |
| |
| writer.write(&batch).unwrap(); |
| writer.finish().unwrap(); |
| } |
| file.rewind().unwrap(); |
| |
| { |
| let mut reader = FileReader::try_new(file, None).unwrap(); |
| while let Some(Ok(read_batch)) = reader.next() { |
| read_batch |
| .columns() |
| .iter() |
| .zip(batch.columns()) |
| .for_each(|(a, b)| { |
| assert_eq!(a.data_type(), b.data_type()); |
| assert_eq!(a.len(), b.len()); |
| assert_eq!(a.null_count(), b.null_count()); |
| }); |
| } |
| } |
| } |
| |
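/// Write and read back a batch containing `Null` columns with the given IPC write options,
/// checking that column types, lengths, and null counts survive the round trip.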
| fn write_null_file(options: IpcWriteOptions) { |
| let schema = Schema::new(vec![ |
| Field::new("nulls", DataType::Null, true), |
| Field::new("int32s", DataType::Int32, false), |
| Field::new("nulls2", DataType::Null, true), |
| Field::new("f64s", DataType::Float64, false), |
| ]); |
| let array1 = NullArray::new(32); |
| let array2 = Int32Array::from(vec![1; 32]); |
| let array3 = NullArray::new(32); |
| let array4 = Float64Array::from(vec![f64::NAN; 32]); |
| let batch = RecordBatch::try_new( |
| Arc::new(schema.clone()), |
| vec![ |
| Arc::new(array1) as ArrayRef, |
| Arc::new(array2) as ArrayRef, |
| Arc::new(array3) as ArrayRef, |
| Arc::new(array4) as ArrayRef, |
| ], |
| ) |
| .unwrap(); |
| let mut file = tempfile::tempfile().unwrap(); |
| { |
| let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap(); |
| |
| writer.write(&batch).unwrap(); |
| writer.finish().unwrap(); |
| } |
| |
| file.rewind().unwrap(); |
| |
| { |
| let reader = FileReader::try_new(file, None).unwrap(); |
| reader.for_each(|maybe_batch| { |
| maybe_batch |
| .unwrap() |
| .columns() |
| .iter() |
| .zip(batch.columns()) |
| .for_each(|(a, b)| { |
| assert_eq!(a.data_type(), b.data_type()); |
| assert_eq!(a.len(), b.len()); |
| assert_eq!(a.null_count(), b.null_count()); |
| }); |
| }); |
| } |
| } |

#[test]
| fn test_write_null_file_v4() { |
| write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap()); |
| write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap()); |
| write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap()); |
| write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap()); |
| } |
| |
| #[test] |
| fn test_write_null_file_v5() { |
| write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); |
| write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap()); |
| } |
| |
| #[test] |
| fn track_union_nested_dict() { |
| let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect(); |
| |
| let array = Arc::new(inner) as ArrayRef; |
| |
| // Dict field with id 2 |
| let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 2, false); |
| let union_fields = [(0, Arc::new(dctfield))].into_iter().collect(); |
| |
| let types = [0, 0, 0].into_iter().collect::<ScalarBuffer<i8>>(); |
| let offsets = [0, 1, 2].into_iter().collect::<ScalarBuffer<i32>>(); |
| |
| let union = UnionArray::try_new(union_fields, types, Some(offsets), vec![array]).unwrap(); |
| |
| let schema = Arc::new(Schema::new(vec![Field::new( |
| "union", |
| union.data_type().clone(), |
| false, |
| )])); |
| |
| let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap(); |
| |
| let gen = IpcDataGenerator {}; |
| let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); |
| gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) |
| .unwrap(); |
| |
// Because the dictionary tracker is created with `preserve_dict_id = true`, the dictionary ID
// declared in the schema (2) is preserved rather than reassigned, so the dictionary is keyed to 2
| assert!(dict_tracker.written.contains_key(&2)); |
| } |
| |
| #[test] |
| fn track_struct_nested_dict() { |
| let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect(); |
| |
| let array = Arc::new(inner) as ArrayRef; |
| |
| // Dict field with id 2 |
| let dctfield = Arc::new(Field::new_dict( |
| "dict", |
| array.data_type().clone(), |
| false, |
| 2, |
| false, |
| )); |
| |
| let s = StructArray::from(vec![(dctfield, array)]); |
| let struct_array = Arc::new(s) as ArrayRef; |
| |
| let schema = Arc::new(Schema::new(vec![Field::new( |
| "struct", |
| struct_array.data_type().clone(), |
| false, |
| )])); |
| |
| let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); |
| |
| let gen = IpcDataGenerator {}; |
| let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); |
| gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) |
| .unwrap(); |
| |
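// With `preserve_dict_id` enabled, the dictionary ID declared in the schema (2) is preserved.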
| assert!(dict_tracker.written.contains_key(&2)); |
| } |
| |
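/// Write and read back a batch containing a sparse union column with the given IPC write options.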
| fn write_union_file(options: IpcWriteOptions) { |
| let schema = Schema::new(vec![Field::new_union( |
| "union", |
| vec![0, 1], |
| vec![ |
| Field::new("a", DataType::Int32, false), |
| Field::new("c", DataType::Float64, false), |
| ], |
| UnionMode::Sparse, |
| )]); |
| let mut builder = UnionBuilder::with_capacity_sparse(5); |
| builder.append::<Int32Type>("a", 1).unwrap(); |
| builder.append_null::<Int32Type>("a").unwrap(); |
| builder.append::<Float64Type>("c", 3.0).unwrap(); |
| builder.append_null::<Float64Type>("c").unwrap(); |
| builder.append::<Int32Type>("a", 4).unwrap(); |
| let union = builder.build().unwrap(); |
| |
| let batch = |
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef]) |
| .unwrap(); |
| |
| let mut file = tempfile::tempfile().unwrap(); |
| { |
| let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap(); |
| |
| writer.write(&batch).unwrap(); |
| writer.finish().unwrap(); |
| } |
| file.rewind().unwrap(); |
| |
| { |
| let reader = FileReader::try_new(file, None).unwrap(); |
| reader.for_each(|maybe_batch| { |
| maybe_batch |
| .unwrap() |
| .columns() |
| .iter() |
| .zip(batch.columns()) |
| .for_each(|(a, b)| { |
| assert_eq!(a.data_type(), b.data_type()); |
| assert_eq!(a.len(), b.len()); |
| assert_eq!(a.null_count(), b.null_count()); |
| }); |
| }); |
| } |
| } |
| |
| #[test] |
| fn test_write_union_file_v4_v5() { |
| write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap()); |
| write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); |
| } |
| |
| #[test] |
| fn test_write_view_types() { |
| const LONG_TEST_STRING: &str = |
| "This is a long string to make sure binary view array handles it"; |
| let schema = Schema::new(vec![ |
| Field::new("field1", DataType::BinaryView, true), |
| Field::new("field2", DataType::Utf8View, true), |
| ]); |
| let values: Vec<Option<&[u8]>> = vec![ |
| Some(b"foo"), |
| Some(b"bar"), |
| Some(LONG_TEST_STRING.as_bytes()), |
| ]; |
| let binary_array = BinaryViewArray::from_iter(values); |
| let utf8_array = |
| StringViewArray::from_iter(vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]); |
| let record_batch = RecordBatch::try_new( |
| Arc::new(schema.clone()), |
| vec![Arc::new(binary_array), Arc::new(utf8_array)], |
| ) |
| .unwrap(); |
| |
| let mut file = tempfile::tempfile().unwrap(); |
| { |
| let mut writer = FileWriter::try_new(&mut file, &schema).unwrap(); |
| writer.write(&record_batch).unwrap(); |
| writer.finish().unwrap(); |
| } |
| file.rewind().unwrap(); |
| { |
| let mut reader = FileReader::try_new(&file, None).unwrap(); |
| let read_batch = reader.next().unwrap().unwrap(); |
| read_batch |
| .columns() |
| .iter() |
| .zip(record_batch.columns()) |
| .for_each(|(a, b)| { |
| assert_eq!(a, b); |
| }); |
| } |
| file.rewind().unwrap(); |
| { |
| let mut reader = FileReader::try_new(&file, Some(vec![0])).unwrap(); |
| let read_batch = reader.next().unwrap().unwrap(); |
| assert_eq!(read_batch.num_columns(), 1); |
| let read_array = read_batch.column(0); |
| let write_array = record_batch.column(0); |
| assert_eq!(read_array, write_array); |
| } |
| } |
| |
| #[test] |
| fn truncate_ipc_record_batch() { |
| fn create_batch(rows: usize) -> RecordBatch { |
| let schema = Schema::new(vec![ |
| Field::new("a", DataType::Int32, false), |
| Field::new("b", DataType::Utf8, false), |
| ]); |
| |
| let a = Int32Array::from_iter_values(0..rows as i32); |
| let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string())); |
| |
| RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap() |
| } |
| |
| let big_record_batch = create_batch(65536); |
| |
| let length = 5; |
| let small_record_batch = create_batch(length); |
| |
| let offset = 2; |
| let record_batch_slice = big_record_batch.slice(offset, length); |
| assert!( |
| serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len() |
| ); |
| assert_eq!( |
| serialize_stream(&small_record_batch).len(), |
| serialize_stream(&record_batch_slice).len() |
| ); |
| |
| assert_eq!( |
| deserialize_stream(serialize_stream(&record_batch_slice)), |
| record_batch_slice |
| ); |
| } |
| |
| #[test] |
| fn truncate_ipc_record_batch_with_nulls() { |
| fn create_batch() -> RecordBatch { |
| let schema = Schema::new(vec![ |
| Field::new("a", DataType::Int32, true), |
| Field::new("b", DataType::Utf8, true), |
| ]); |
| |
| let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]); |
| let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]); |
| |
| RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap() |
| } |
| |
| let record_batch = create_batch(); |
| let record_batch_slice = record_batch.slice(1, 2); |
| let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); |
| |
| assert!( |
| serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() |
| ); |
| |
| assert!(deserialized_batch.column(0).is_null(0)); |
| assert!(deserialized_batch.column(0).is_valid(1)); |
| assert!(deserialized_batch.column(1).is_valid(0)); |
| assert!(deserialized_batch.column(1).is_valid(1)); |
| |
| assert_eq!(record_batch_slice, deserialized_batch); |
| } |
| |
| #[test] |
| fn truncate_ipc_dictionary_array() { |
| fn create_batch() -> RecordBatch { |
| let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] |
| .into_iter() |
| .collect(); |
| let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect(); |
| |
| let array = DictionaryArray::new(keys, Arc::new(values)); |
| |
| let schema = Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]); |
| |
| RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() |
| } |
| |
| let record_batch = create_batch(); |
| let record_batch_slice = record_batch.slice(1, 2); |
| let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); |
| |
| assert!( |
| serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() |
| ); |
| |
| assert!(deserialized_batch.column(0).is_valid(0)); |
| assert!(deserialized_batch.column(0).is_null(1)); |
| |
| assert_eq!(record_batch_slice, deserialized_batch); |
| } |
| |
| #[test] |
| fn truncate_ipc_struct_array() { |
| fn create_batch() -> RecordBatch { |
| let strings: StringArray = [Some("foo"), None, Some("bar"), Some("baz")] |
| .into_iter() |
| .collect(); |
| let ints: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect(); |
| |
| let struct_array = StructArray::from(vec![ |
| ( |
| Arc::new(Field::new("s", DataType::Utf8, true)), |
| Arc::new(strings) as ArrayRef, |
| ), |
| ( |
| Arc::new(Field::new("c", DataType::Int32, true)), |
| Arc::new(ints) as ArrayRef, |
| ), |
| ]); |
| |
| let schema = Schema::new(vec![Field::new( |
| "struct_array", |
| struct_array.data_type().clone(), |
| true, |
| )]); |
| |
| RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap() |
| } |
| |
| let record_batch = create_batch(); |
| let record_batch_slice = record_batch.slice(1, 2); |
| let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); |
| |
| assert!( |
| serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() |
| ); |
| |
| let structs = deserialized_batch |
| .column(0) |
| .as_any() |
| .downcast_ref::<StructArray>() |
| .unwrap(); |
| |
| assert!(structs.column(0).is_null(0)); |
| assert!(structs.column(0).is_valid(1)); |
| assert!(structs.column(1).is_valid(0)); |
| assert!(structs.column(1).is_null(1)); |
| assert_eq!(record_batch_slice, deserialized_batch); |
| } |
| |
| #[test] |
| fn truncate_ipc_string_array_with_all_empty_string() { |
| fn create_batch() -> RecordBatch { |
| let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); |
| let a = StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]); |
| RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap() |
| } |
| |
| let record_batch = create_batch(); |
| let record_batch_slice = record_batch.slice(0, 1); |
| let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); |
| |
| assert!( |
| serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() |
| ); |
| assert_eq!(record_batch_slice, deserialized_batch); |
| } |
| |
| #[test] |
| fn test_stream_writer_writes_array_slice() { |
| let array = UInt32Array::from(vec![Some(1), Some(2), Some(3)]); |
| assert_eq!( |
| vec![Some(1), Some(2), Some(3)], |
| array.iter().collect::<Vec<_>>() |
| ); |
| |
| let sliced = array.slice(1, 2); |
| assert_eq!(vec![Some(2), Some(3)], sliced.iter().collect::<Vec<_>>()); |
| |
| let batch = RecordBatch::try_new( |
| Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])), |
| vec![Arc::new(sliced)], |
| ) |
| .expect("new batch"); |
| |
| let mut writer = StreamWriter::try_new(vec![], batch.schema_ref()).expect("new writer"); |
| writer.write(&batch).expect("write"); |
| let outbuf = writer.into_inner().expect("inner"); |
| |
| let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader"); |
| let read_batch = reader.next().unwrap().expect("read batch"); |
| |
| let read_array: &UInt32Array = read_batch.column(0).as_primitive(); |
| assert_eq!( |
| vec![Some(2), Some(3)], |
| read_array.iter().collect::<Vec<_>>() |
| ); |
| } |
| |
| #[test] |
| fn encode_bools_slice() { |
| // Test case for https://github.com/apache/arrow-rs/issues/3496 |
| assert_bool_roundtrip([true, false], 1, 1); |
| |
| // slice somewhere in the middle |
| assert_bool_roundtrip( |
| [ |
| true, false, true, true, false, false, true, true, true, false, false, false, true, |
| true, true, true, false, false, false, false, true, true, true, true, true, false, |
| false, false, false, false, |
| ], |
| 13, |
| 17, |
| ); |
| |
| // start at byte boundary, end in the middle |
| assert_bool_roundtrip( |
| [ |
| true, false, true, true, false, false, true, true, true, false, false, false, |
| ], |
| 8, |
| 2, |
| ); |
| |
// start and stop at a byte boundary
| assert_bool_roundtrip( |
| [ |
| true, false, true, true, false, false, true, true, true, false, false, false, true, |
| true, true, true, true, false, false, false, false, false, |
| ], |
| 8, |
| 8, |
| ); |
| } |
| |
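/// Slice a boolean batch at (`offset`, `length`) and check that the slice round-trips
/// through the IPC stream format.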
| fn assert_bool_roundtrip<const N: usize>(bools: [bool; N], offset: usize, length: usize) { |
| let val_bool_field = Field::new("val", DataType::Boolean, false); |
| |
| let schema = Arc::new(Schema::new(vec![val_bool_field])); |
| |
| let bools = BooleanArray::from(bools.to_vec()); |
| |
| let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap(); |
| let batch = batch.slice(offset, length); |
| |
| let data = serialize_stream(&batch); |
| let batch2 = deserialize_stream(data); |
| assert_eq!(batch, batch2); |
| } |
| |
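// Check that `into_zero_offset_run_array` re-encodes a sliced `RunArray` into an equivalent
// array with zero offset, for every slice length and for slices anchored at either end.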
| #[test] |
| fn test_run_array_unslice() { |
| let total_len = 80; |
| let vals: Vec<Option<i32>> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)]; |
| let repeats: Vec<usize> = vec![3, 4, 1, 2]; |
| let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len); |
| for ix in 0_usize..32 { |
| let repeat: usize = repeats[ix % repeats.len()]; |
| let val: Option<i32> = vals[ix % vals.len()]; |
| input_array.resize(input_array.len() + repeat, val); |
| } |
| |
| // Encode the input_array to run array |
| let mut builder = |
| PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len()); |
| builder.extend(input_array.iter().copied()); |
| let run_array = builder.finish(); |
| |
| // test for all slice lengths. |
| for slice_len in 1..=total_len { |
| // test for offset = 0, slice length = slice_len |
| let sliced_run_array: RunArray<Int16Type> = |
| run_array.slice(0, slice_len).into_data().into(); |
| |
| // Create unsliced run array. |
| let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap(); |
| let typed = unsliced_run_array |
| .downcast::<PrimitiveArray<Int32Type>>() |
| .unwrap(); |
| let expected: Vec<Option<i32>> = input_array.iter().take(slice_len).copied().collect(); |
| let actual: Vec<Option<i32>> = typed.into_iter().collect(); |
| assert_eq!(expected, actual); |
| |
| // test for offset = total_len - slice_len, length = slice_len |
| let sliced_run_array: RunArray<Int16Type> = run_array |
| .slice(total_len - slice_len, slice_len) |
| .into_data() |
| .into(); |
| |
| // Create unsliced run array. |
| let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap(); |
| let typed = unsliced_run_array |
| .downcast::<PrimitiveArray<Int32Type>>() |
| .unwrap(); |
| let expected: Vec<Option<i32>> = input_array |
| .iter() |
| .skip(total_len - slice_len) |
| .copied() |
| .collect(); |
| let actual: Vec<Option<i32>> = typed.into_iter().collect(); |
| assert_eq!(expected, actual); |
| } |
| } |
| |
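/// Generate a list array with 100,000 rows, each containing the row index repeated three times.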
| fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> { |
| let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new()); |
| |
| for i in 0..100_000 { |
| for value in [i, i, i] { |
| ls.values().append_value(value); |
| } |
ls.append(true);
| } |
| |
| ls.finish() |
| } |
| |
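/// Generate a nested list array with 10,000 rows; each row holds ten inner lists of four values.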
| fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> { |
| let mut ls = |
| GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new())); |
| |
| for _i in 0..10_000 { |
| for j in 0..10 { |
| for value in [j, j, j, j] { |
| ls.values().values().append_value(value); |
| } |
ls.values().append(true);
| } |
| ls.append(true); |
| } |
| |
| ls.finish() |
| } |
| |
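/// Generate a map array with 100,000 rows, each containing three identical key/value entries.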
| fn generate_map_array_data() -> MapArray { |
| let keys_builder = UInt32Builder::new(); |
| let values_builder = UInt32Builder::new(); |
| |
| let mut builder = MapBuilder::new(None, keys_builder, values_builder); |
| |
| for i in 0..100_000 { |
| for _j in 0..3 { |
| builder.keys().append_value(i); |
| builder.values().append_value(i * 2); |
| } |
| builder.append(true).unwrap(); |
| } |
| |
| builder.finish() |
| } |
| |
/// Ensure that both the full and the sliced batch round-trip (serialize then deserialize) back
/// to the original input, and that the serialized sliced version is at least
/// `expected_size_factor` times smaller than the serialized full batch.
| fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) { |
| // test both full and sliced versions |
| let in_sliced = in_batch.slice(999, 1); |
| |
| let bytes_batch = serialize_file(&in_batch); |
| let bytes_sliced = serialize_file(&in_sliced); |
| |
| // serializing 1 row should be significantly smaller than serializing 100,000 |
| assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor)); |
| |
| // ensure both are still valid and equal to originals |
| let out_batch = deserialize_file(bytes_batch); |
| assert_eq!(in_batch, out_batch); |
| |
| let out_sliced = deserialize_file(bytes_sliced); |
| assert_eq!(in_sliced, out_sliced); |
| } |
| |
| #[test] |
| fn encode_lists() { |
| let val_inner = Field::new("item", DataType::UInt32, true); |
| let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false); |
| let schema = Arc::new(Schema::new(vec![val_list_field])); |
| |
| let values = Arc::new(generate_list_data::<i32>()); |
| |
| let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); |
| roundtrip_ensure_sliced_smaller(in_batch, 1000); |
| } |
| |
| #[test] |
| fn encode_empty_list() { |
| let val_inner = Field::new("item", DataType::UInt32, true); |
| let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false); |
| let schema = Arc::new(Schema::new(vec![val_list_field])); |
| |
| let values = Arc::new(generate_list_data::<i32>()); |
| |
| let in_batch = RecordBatch::try_new(schema, vec![values]) |
| .unwrap() |
| .slice(999, 0); |
| let out_batch = deserialize_file(serialize_file(&in_batch)); |
| assert_eq!(in_batch, out_batch); |
| } |
| |
| #[test] |
| fn encode_large_lists() { |
| let val_inner = Field::new("item", DataType::UInt32, true); |
| let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false); |
| let schema = Arc::new(Schema::new(vec![val_list_field])); |
| |
| let values = Arc::new(generate_list_data::<i64>()); |
| |
// ensure both the full and sliced versions round-trip to the original input, and that the
// serialized sliced version is significantly smaller than the serialized full batch
| let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); |
| roundtrip_ensure_sliced_smaller(in_batch, 1000); |
| } |
| |
| #[test] |
| fn encode_nested_lists() { |
| let inner_int = Arc::new(Field::new("item", DataType::UInt32, true)); |
| let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true)); |
| let list_field = Field::new("val", DataType::List(inner_list_field), true); |
| let schema = Arc::new(Schema::new(vec![list_field])); |
| |
| let values = Arc::new(generate_nested_list_data::<i32>()); |
| |
| let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); |
| roundtrip_ensure_sliced_smaller(in_batch, 1000); |
| } |
| |
| #[test] |
| fn encode_map_array() { |
| let keys = Arc::new(Field::new("keys", DataType::UInt32, false)); |
| let values = Arc::new(Field::new("values", DataType::UInt32, true)); |
| let map_field = Field::new_map("map", "entries", keys, values, false, true); |
| let schema = Arc::new(Schema::new(vec![map_field])); |
| |
| let values = Arc::new(generate_map_array_data()); |
| |
| let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); |
| roundtrip_ensure_sliced_smaller(in_batch, 1000); |
| } |
| |
| #[test] |
| fn test_decimal128_alignment16_is_sufficient() { |
| const IPC_ALIGNMENT: usize = 16; |
| |
| // Test a bunch of different dimensions to ensure alignment is never an issue. |
| // For example, if we only test `num_cols = 1` then even with alignment 8 this |
| // test would _happen_ to pass, even though for different dimensions like |
| // `num_cols = 2` it would fail. |
| for num_cols in [1, 2, 3, 17, 50, 73, 99] { |
| let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle |
| |
| let mut fields = Vec::new(); |
| let mut arrays = Vec::new(); |
| for i in 0..num_cols { |
let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
| let array = Decimal128Array::from(vec![num_cols as i128; num_rows]); |
| fields.push(field); |
| arrays.push(Arc::new(array) as Arc<dyn Array>); |
| } |
| let schema = Schema::new(fields); |
| let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap(); |
| |
| let mut writer = FileWriter::try_new_with_options( |
| Vec::new(), |
| batch.schema_ref(), |
| IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), |
| ) |
| .unwrap(); |
| writer.write(&batch).unwrap(); |
| writer.finish().unwrap(); |
| |
| let out: Vec<u8> = writer.into_inner().unwrap(); |
| |
| let buffer = Buffer::from_vec(out); |
| let trailer_start = buffer.len() - 10; |
| let footer_len = |
| read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap(); |
| let footer = |
| root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap(); |
| |
| let schema = fb_to_schema(footer.schema().unwrap()); |
| |
| // Importantly we set `require_alignment`, checking that 16-byte alignment is sufficient |
| // for `read_record_batch` later on to read the data in a zero-copy manner. |
| let decoder = |
| FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true); |
| |
| let batches = footer.recordBatches().unwrap(); |
| |
| let block = batches.get(0); |
| let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; |
| let data = buffer.slice_with_length(block.offset() as _, block_len); |
| |
| let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap(); |
| |
| assert_eq!(batch, batch2); |
| } |
| } |
| |
| #[test] |
| fn test_decimal128_alignment8_is_unaligned() { |
| const IPC_ALIGNMENT: usize = 8; |
| |
| let num_cols = 2; |
| let num_rows = 1; |
| |
| let mut fields = Vec::new(); |
| let mut arrays = Vec::new(); |
| for i in 0..num_cols { |
let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
| let array = Decimal128Array::from(vec![num_cols as i128; num_rows]); |
| fields.push(field); |
| arrays.push(Arc::new(array) as Arc<dyn Array>); |
| } |
| let schema = Schema::new(fields); |
| let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap(); |
| |
| let mut writer = FileWriter::try_new_with_options( |
| Vec::new(), |
| batch.schema_ref(), |
| IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), |
| ) |
| .unwrap(); |
| writer.write(&batch).unwrap(); |
| writer.finish().unwrap(); |
| |
| let out: Vec<u8> = writer.into_inner().unwrap(); |
| |
| let buffer = Buffer::from_vec(out); |
| let trailer_start = buffer.len() - 10; |
| let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap(); |
| let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap(); |
| |
| let schema = fb_to_schema(footer.schema().unwrap()); |
| |
| // Importantly we set `require_alignment`, otherwise the error later is suppressed due to copying |
| // to an aligned buffer in `ArrayDataBuilder.build_aligned`. |
| let decoder = |
| FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true); |
| |
| let batches = footer.recordBatches().unwrap(); |
| |
| let block = batches.get(0); |
| let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; |
| let data = buffer.slice_with_length(block.offset() as _, block_len); |
| |
| let result = decoder.read_record_batch(block, &data); |
| |
| let error = result.unwrap_err(); |
| assert_eq!( |
| error.to_string(), |
| "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \ |
| offset from expected alignment of 16 by 8" |
| ); |
| } |
| } |