| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use crate::CompressionType; |
| use arrow_buffer::Buffer; |
| use arrow_schema::ArrowError; |
| |
/// Length-prefix value signalling that the body following the prefix is stored uncompressed.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the little-endian `i64` length prefix that precedes each non-empty body.
const LENGTH_OF_PREFIX_DATA: i64 = 8;
| |
/// Reusable state carried across successive compression calls.
///
/// When the `zstd` feature is enabled this owns a zstd bulk compressor. Passing the same
/// context to repeated calls reuses that compressor instead of paying the cost of setting
/// up a fresh one for every buffer.
pub struct CompressionContext {
    #[cfg(feature = "zstd")]
    compressor: zstd::bulk::Compressor<'static>,
}
| |
| // the reason we allow derivable_impls here is because when zstd feature is not enabled, this |
| // becomes derivable. however with zstd feature want to be explicit about the compression level. |
| #[allow(clippy::derivable_impls)] |
| impl Default for CompressionContext { |
| fn default() -> Self { |
| CompressionContext { |
| // safety: `new` here will only return error here if using an invalid compression level |
| #[cfg(feature = "zstd")] |
| compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL) |
| .expect("can use default compression level"), |
| } |
| } |
| } |
| |
| impl std::fmt::Debug for CompressionContext { |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| let mut ds = f.debug_struct("CompressionContext"); |
| |
| #[cfg(feature = "zstd")] |
| ds.field("compressor", &"zstd::bulk::Compressor"); |
| |
| ds.finish() |
| } |
| } |
| |
/// The compression algorithm applied to the buffers of an IPC stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format.
    Lz4Frame,
    /// Zstandard.
    Zstd,
}
| |
| impl TryFrom<CompressionType> for CompressionCodec { |
| type Error = ArrowError; |
| |
| fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> { |
| match compression_type { |
| CompressionType::ZSTD => Ok(CompressionCodec::Zstd), |
| CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame), |
| other_type => Err(ArrowError::NotYetImplemented(format!( |
| "compression type {other_type:?} not supported " |
| ))), |
| } |
| } |
| } |
| |
| impl CompressionCodec { |
| /// Compresses the data in `input` to `output` and appends the |
| /// data using the specified compression mechanism. |
| /// |
| /// returns the number of bytes written to the stream |
| /// |
| /// Writes this format to output: |
| /// ```text |
| /// [8 bytes]: uncompressed length |
| /// [remaining bytes]: compressed data stream |
| /// ``` |
| pub(crate) fn compress_to_vec( |
| &self, |
| input: &[u8], |
| output: &mut Vec<u8>, |
| context: &mut CompressionContext, |
| ) -> Result<usize, ArrowError> { |
| let uncompressed_data_len = input.len(); |
| let original_output_len = output.len(); |
| |
| if input.is_empty() { |
| // empty input, nothing to do |
| } else { |
| // write compressed data directly into the output buffer |
| output.extend_from_slice(&uncompressed_data_len.to_le_bytes()); |
| self.compress(input, output, context)?; |
| |
| let compression_len = output.len() - original_output_len; |
| if compression_len > uncompressed_data_len { |
| // length of compressed data was larger than |
| // uncompressed data, use the uncompressed data with |
| // length -1 to indicate that we don't compress the |
| // data |
| output.truncate(original_output_len); |
| output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes()); |
| output.extend_from_slice(input); |
| } |
| } |
| Ok(output.len() - original_output_len) |
| } |
| |
| /// Decompresses the input into a [`Buffer`] |
| /// |
| /// The input should look like: |
| /// ```text |
| /// [8 bytes]: uncompressed length |
| /// [remaining bytes]: compressed data stream |
| /// ``` |
| pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result<Buffer, ArrowError> { |
| // read the first 8 bytes to determine if the data is |
| // compressed |
| let decompressed_length = read_uncompressed_size(input); |
| let buffer = if decompressed_length == 0 { |
| // empty |
| Buffer::from([]) |
| } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA { |
| // no compression |
| input.slice(LENGTH_OF_PREFIX_DATA as usize) |
| } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) { |
| // decompress data using the codec |
| let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..]; |
| let v = self.decompress(input_data, decompressed_length as _)?; |
| Buffer::from_vec(v) |
| } else { |
| return Err(ArrowError::IpcError(format!( |
| "Invalid uncompressed length: {decompressed_length}" |
| ))); |
| }; |
| Ok(buffer) |
| } |
| |
| /// Compress the data in input buffer and write to output buffer |
| /// using the specified compression |
| fn compress( |
| &self, |
| input: &[u8], |
| output: &mut Vec<u8>, |
| context: &mut CompressionContext, |
| ) -> Result<(), ArrowError> { |
| match self { |
| CompressionCodec::Lz4Frame => compress_lz4(input, output), |
| CompressionCodec::Zstd => compress_zstd(input, output, context), |
| } |
| } |
| |
| /// Decompress the data in input buffer and write to output buffer |
| /// using the specified compression |
| fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> { |
| let ret = match self { |
| CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?, |
| CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?, |
| }; |
| if ret.len() != decompressed_size { |
| return Err(ArrowError::IpcError(format!( |
| "Expected compressed length of {decompressed_size} got {}", |
| ret.len() |
| ))); |
| } |
| Ok(ret) |
| } |
| } |
| |
/// Appends an LZ4 frame containing `input` to `output`.
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;
    // The encoder writes straight into `output`; `finish` completes the frame.
    let mut frame_encoder = lz4_flex::frame::FrameEncoder::new(output);
    frame_encoder.write_all(input)?;
    frame_encoder
        .finish()
        .map(|_| ())
        .map_err(|e| ArrowError::ExternalError(Box::new(e)))
}
| |
| #[cfg(not(feature = "lz4"))] |
| #[allow(clippy::ptr_arg)] |
| fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> { |
| Err(ArrowError::InvalidArgumentError( |
| "lz4 IPC compression requires the lz4 feature".to_string(), |
| )) |
| } |
| |
/// Decompresses the LZ4 frame stream `input`, whose declared decompressed length is
/// `decompressed_size`.
///
/// At most `decompressed_size + 1` bytes are produced, so a stream that inflates past its
/// declared size yields a result of the wrong length rather than unbounded output.
///
/// # Errors
/// Returns [`ArrowError::IpcError`] if the output buffer cannot be allocated. Decoder I/O
/// errors are propagated.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;
    // `decompressed_size` comes from the buffer's length prefix and may be corrupt or
    // hostile. Reserve fallibly so an absurd value becomes an error instead of a
    // capacity-overflow panic or an allocation-failure abort.
    let mut output = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Unable to allocate {decompressed_size} bytes for lz4 decompression: {e}"
        ))
    })?;
    // Read one byte past the declared size: enough to detect an oversized stream without
    // inflating it fully. `usize` is at most 64 bits, so the cast is lossless.
    let read_limit = (decompressed_size as u64).saturating_add(1);
    lz4_flex::frame::FrameDecoder::new(input)
        .take(read_limit)
        .read_to_end(&mut output)?;
    Ok(output)
}
| |
| #[cfg(not(feature = "lz4"))] |
| #[allow(clippy::ptr_arg)] |
| fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> { |
| Err(ArrowError::InvalidArgumentError( |
| "lz4 IPC decompression requires the lz4 feature".to_string(), |
| )) |
| } |
| |
/// Compresses `input` with the zstd compressor held by `context` and appends the result
/// to `output`.
#[cfg(feature = "zstd")]
fn compress_zstd(
    input: &[u8],
    output: &mut Vec<u8>,
    context: &mut CompressionContext,
) -> Result<(), ArrowError> {
    let compressed = context.compressor.compress(input)?;
    output.extend(compressed);
    Ok(())
}
| |
| #[cfg(not(feature = "zstd"))] |
| #[allow(clippy::ptr_arg)] |
| fn compress_zstd( |
| _input: &[u8], |
| _output: &mut Vec<u8>, |
| _context: &mut CompressionContext, |
| ) -> Result<(), ArrowError> { |
| Err(ArrowError::InvalidArgumentError( |
| "zstd IPC compression requires the zstd feature".to_string(), |
| )) |
| } |
| |
/// Decompresses the zstd stream `input`, whose declared decompressed length is
/// `decompressed_size`.
///
/// At most `decompressed_size + 1` bytes are produced, so a stream that inflates past its
/// declared size yields a result of the wrong length rather than unbounded output.
///
/// # Errors
/// Returns [`ArrowError::IpcError`] if the output buffer cannot be allocated. Decoder
/// construction and I/O errors are propagated.
#[cfg(feature = "zstd")]
fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;
    // `decompressed_size` comes from the buffer's length prefix and may be corrupt or
    // hostile. Reserve fallibly so an absurd value becomes an error instead of a
    // capacity-overflow panic or an allocation-failure abort.
    let mut output = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Unable to allocate {decompressed_size} bytes for zstd decompression: {e}"
        ))
    })?;
    // Read one byte past the declared size: enough to detect an oversized stream without
    // inflating it fully. `usize` is at most 64 bits, so the cast is lossless.
    let read_limit = (decompressed_size as u64).saturating_add(1);
    zstd::Decoder::with_buffer(input)?
        .take(read_limit)
        .read_to_end(&mut output)?;
    Ok(output)
}
| |
| #[cfg(not(feature = "zstd"))] |
| #[allow(clippy::ptr_arg)] |
| fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> { |
| Err(ArrowError::InvalidArgumentError( |
| "zstd IPC decompression requires the zstd feature".to_string(), |
| )) |
| } |
| |
/// Reads the 8-byte little-endian signed length prefix at the start of `buffer`.
///
/// The returned value means:
/// * `LENGTH_NO_COMPRESSED_DATA` (`-1`): the bytes after the prefix are not compressed
/// * `0`: there is no data
/// * a positive number: the uncompressed length of the data that follows
///
/// # Panics
/// Panics if `buffer` is shorter than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
| |
#[cfg(test)]
mod tests {
    /// Runs `input` through `codec`'s raw `compress`/`decompress` pair and checks that the
    /// original bytes come back unchanged.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn assert_roundtrip(codec: super::CompressionCodec, input: &[u8]) {
        let mut compressed: Vec<u8> = Vec::new();
        codec
            .compress(input, &mut compressed, &mut Default::default())
            .unwrap();
        let restored = codec.decompress(compressed.as_slice(), input.len()).unwrap();
        assert_eq!(input, restored.as_slice());
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        assert_roundtrip(super::CompressionCodec::Lz4Frame, b"hello lz4");
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        assert_roundtrip(super::CompressionCodec::Zstd, b"hello zstd");
    }
}