blob: 9bbc6e752c1299c1fa397bed398c0cc1b0e5be1f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::CompressionType;
use arrow_buffer::Buffer;
use arrow_schema::ArrowError;
/// Length-prefix value signalling that the bytes after the prefix are stored uncompressed.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
/// Size in bytes of the little-endian `i64` length prefix that starts every compressed buffer.
const LENGTH_OF_PREFIX_DATA: i64 = std::mem::size_of::<i64>() as i64;
/// Reusable state shared across IPC buffer compression calls.
///
/// With the `zstd` feature enabled this holds a zstd bulk compressor, so that consecutive
/// compression calls can reuse one context instead of initialising a new one every time.
/// Without that feature the context carries no state.
pub struct CompressionContext {
    #[cfg(feature = "zstd")]
    compressor: zstd::bulk::Compressor<'static>,
}

// Without the `zstd` feature this impl is derivable (and clippy would say so); it is written
// by hand so that, with the feature enabled, the zstd compression level is stated explicitly.
#[allow(clippy::derivable_impls)]
impl Default for CompressionContext {
    fn default() -> Self {
        Self {
            // `Compressor::new` only fails for an invalid compression level, and the
            // library's own default level is used here.
            #[cfg(feature = "zstd")]
            compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
                .expect("can use default compression level"),
        }
    }
}

impl std::fmt::Debug for CompressionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CompressionContext");
        // A fixed placeholder string is printed instead of the compressor itself.
        #[cfg(feature = "zstd")]
        builder.field("compressor", &"zstd::bulk::Compressor");
        builder.finish()
    }
}
/// The compression algorithm used for the buffers of an IPC stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCodec {
    /// LZ4 frame format.
    Lz4Frame,
    /// Zstandard.
    Zstd,
}
impl TryFrom<CompressionType> for CompressionCodec {
    type Error = ArrowError;

    /// Maps an IPC metadata compression type onto a codec.
    ///
    /// Returns `ArrowError::NotYetImplemented` for any compression type other than
    /// `LZ4_FRAME` and `ZSTD`.
    fn try_from(compression_type: CompressionType) -> Result<Self, ArrowError> {
        let codec = match compression_type {
            CompressionType::LZ4_FRAME => Self::Lz4Frame,
            CompressionType::ZSTD => Self::Zstd,
            unsupported => {
                return Err(ArrowError::NotYetImplemented(format!(
                    "compression type {unsupported:?} not supported "
                )))
            }
        };
        Ok(codec)
    }
}
impl CompressionCodec {
    /// Compresses `input` and appends the framed result to `output`.
    ///
    /// Bytes already present in `output` are left untouched. Returns the number of bytes
    /// appended to `output`.
    ///
    /// Writes this format to output:
    /// ```text
    /// [8 bytes]: uncompressed length (64-bit little-endian signed integer)
    /// [remaining bytes]: compressed data stream
    /// ```
    ///
    /// If the framed compressed form (prefix included) is longer than `input`, the
    /// compressed bytes are discarded and `input` is appended raw behind a `-1` prefix
    /// instead. An empty `input` appends nothing and returns `0`.
    ///
    /// # Errors
    ///
    /// Returns the codec's error (including when the codec's cargo feature is disabled).
    /// On error, `output` is restored to its original length.
    pub(crate) fn compress_to_vec(
        &self,
        input: &[u8],
        output: &mut Vec<u8>,
        context: &mut CompressionContext,
    ) -> Result<usize, ArrowError> {
        let uncompressed_data_len = input.len();
        let original_output_len = output.len();
        if input.is_empty() {
            // empty input: no prefix and no data are written
            return Ok(0);
        }
        // The prefix is always 8 bytes on the wire. `usize::to_le_bytes` would only emit 4
        // bytes on 32-bit targets, so widen to i64 first; slice lengths never exceed
        // `isize::MAX`, so this conversion is lossless.
        output.extend_from_slice(&(uncompressed_data_len as i64).to_le_bytes());
        // write compressed data directly into the output buffer
        if let Err(e) = self.compress(input, output, context) {
            // don't leave a dangling prefix / partial frame behind
            output.truncate(original_output_len);
            return Err(e);
        }
        let compression_len = output.len() - original_output_len;
        if compression_len > uncompressed_data_len {
            // The compressed frame (prefix included) is larger than the raw data: store the
            // data uncompressed, marked with a length of -1.
            output.truncate(original_output_len);
            output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
            output.extend_from_slice(input);
        }
        Ok(output.len() - original_output_len)
    }

    /// Decompresses a framed buffer into a [`Buffer`].
    ///
    /// The input should look like:
    /// ```text
    /// [8 bytes]: uncompressed length
    /// [remaining bytes]: compressed data stream
    /// ```
    ///
    /// A length of `0` yields an empty buffer. A length of `-1` yields a zero-copy slice
    /// of the bytes after the prefix.
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::IpcError` in these cases:
    /// * `input` is shorter than the prefix
    /// * the length is negative and not `-1`
    /// * the decompressed size does not match the prefix
    ///
    /// Codec errors are also propagated.
    pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result<Buffer, ArrowError> {
        let prefix_len = LENGTH_OF_PREFIX_DATA as usize;
        // Malformed input must produce an error rather than a slice-index panic.
        if input.len() < prefix_len {
            return Err(ArrowError::IpcError(format!(
                "Compressed buffer is too short to contain its length prefix: expected at least {prefix_len} bytes, got {}",
                input.len()
            )));
        }
        // read the first 8 bytes to determine if the data is compressed
        let decompressed_length = read_uncompressed_size(input);
        let buffer = if decompressed_length == 0 {
            // empty
            Buffer::from([])
        } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA {
            // no compression
            input.slice(prefix_len)
        } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) {
            // decompress data using the codec
            let v = self.decompress(&input[prefix_len..], decompressed_length)?;
            Buffer::from_vec(v)
        } else {
            return Err(ArrowError::IpcError(format!(
                "Invalid uncompressed length: {decompressed_length}"
            )));
        };
        Ok(buffer)
    }

    /// Appends the raw compressed stream for `input` to `output` using this codec.
    fn compress(
        &self,
        input: &[u8],
        output: &mut Vec<u8>,
        context: &mut CompressionContext,
    ) -> Result<(), ArrowError> {
        match self {
            CompressionCodec::Lz4Frame => compress_lz4(input, output),
            CompressionCodec::Zstd => compress_zstd(input, output, context),
        }
    }

    /// Decompresses a raw compressed stream, which must expand to exactly `decompressed_size`
    /// bytes.
    fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
        let decompressed = match self {
            CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?,
            CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?,
        };
        if decompressed.len() != decompressed_size {
            return Err(ArrowError::IpcError(format!(
                "Expected decompressed length of {decompressed_size} got {}",
                decompressed.len()
            )));
        }
        Ok(decompressed)
    }
}
/// Appends an LZ4 frame encoding `input` to `output`.
#[cfg(feature = "lz4")]
fn compress_lz4(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
    use std::io::Write;
    let mut frame_encoder = lz4_flex::frame::FrameEncoder::new(output);
    frame_encoder.write_all(input)?;
    // `finish` writes the frame trailer; its error type is not an io::Error.
    match frame_encoder.finish() {
        Ok(_) => Ok(()),
        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
    }
}
/// Fallback used when the `lz4` feature is disabled: always returns an error.
#[cfg(not(feature = "lz4"))]
#[allow(clippy::ptr_arg)]
fn compress_lz4(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
    let message = String::from("lz4 IPC compression requires the lz4 feature");
    Err(ArrowError::InvalidArgumentError(message))
}
/// Decodes an LZ4 frame that is expected to expand to `decompressed_size` bytes.
///
/// `decompressed_size` comes from the (untrusted) buffer prefix, so this function guards
/// against it in two ways:
/// * The allocation is attempted fallibly. An absurd size returns an error instead of
///   aborting the process.
/// * Decoding stops one byte past the expected size. An oversized stream therefore cannot
///   inflate without bound, and the caller still sees a length mismatch.
#[cfg(feature = "lz4")]
fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;
    let mut output = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Cannot allocate {decompressed_size} bytes for lz4 decompression: {e}"
        ))
    })?;
    let read_limit = u64::try_from(decompressed_size)
        .unwrap_or(u64::MAX)
        .saturating_add(1);
    lz4_flex::frame::FrameDecoder::new(input)
        .take(read_limit)
        .read_to_end(&mut output)?;
    Ok(output)
}
/// Fallback used when the `lz4` feature is disabled: always returns an error.
#[cfg(not(feature = "lz4"))]
#[allow(clippy::ptr_arg)]
fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    let message = String::from("lz4 IPC decompression requires the lz4 feature");
    Err(ArrowError::InvalidArgumentError(message))
}
/// Appends a zstd frame encoding `input` to `output`, reusing the compressor held in
/// `context`.
#[cfg(feature = "zstd")]
fn compress_zstd(
    input: &[u8],
    output: &mut Vec<u8>,
    context: &mut CompressionContext,
) -> Result<(), ArrowError> {
    let mut compressed = context.compressor.compress(input)?;
    output.append(&mut compressed);
    Ok(())
}
/// Fallback used when the `zstd` feature is disabled: always returns an error.
#[cfg(not(feature = "zstd"))]
#[allow(clippy::ptr_arg)]
fn compress_zstd(
    _input: &[u8],
    _output: &mut Vec<u8>,
    _context: &mut CompressionContext,
) -> Result<(), ArrowError> {
    let message = String::from("zstd IPC compression requires the zstd feature");
    Err(ArrowError::InvalidArgumentError(message))
}
/// Decodes a zstd stream that is expected to expand to `decompressed_size` bytes.
///
/// `decompressed_size` comes from the (untrusted) buffer prefix, so this function guards
/// against it in two ways:
/// * The allocation is attempted fallibly. An absurd size returns an error instead of
///   aborting the process.
/// * Decoding stops one byte past the expected size. An oversized stream therefore cannot
///   inflate without bound, and the caller still sees a length mismatch.
#[cfg(feature = "zstd")]
fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    use std::io::Read;
    let mut output = Vec::new();
    output.try_reserve_exact(decompressed_size).map_err(|e| {
        ArrowError::IpcError(format!(
            "Cannot allocate {decompressed_size} bytes for zstd decompression: {e}"
        ))
    })?;
    let read_limit = u64::try_from(decompressed_size)
        .unwrap_or(u64::MAX)
        .saturating_add(1);
    zstd::Decoder::with_buffer(input)?
        .take(read_limit)
        .read_to_end(&mut output)?;
    Ok(output)
}
/// Fallback used when the `zstd` feature is disabled: always returns an error.
#[cfg(not(feature = "zstd"))]
#[allow(clippy::ptr_arg)]
fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, ArrowError> {
    let message = String::from("zstd IPC decompression requires the zstd feature");
    Err(ArrowError::InvalidArgumentError(message))
}
/// Reads the length prefix at the start of a compressed IPC buffer.
///
/// The prefix is a 64-bit little-endian signed integer:
/// * `LENGTH_NO_COMPRESSED_DATA` (`-1`): the data that follows is not compressed
/// * `0`: there is no data
/// * a positive number: the uncompressed length of the data that follows
///
/// # Panics
///
/// Panics if `buffer` is shorter than 8 bytes.
#[inline]
fn read_uncompressed_size(buffer: &[u8]) -> i64 {
    let mut prefix = [0_u8; 8];
    prefix.copy_from_slice(&buffer[..8]);
    i64::from_le_bytes(prefix)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `input` through the framed `compress_to_vec` API, appending to a vector that
    /// already holds some bytes. Checks that those bytes survive and that the returned
    /// count matches what was appended. For non-empty input, also checks that
    /// `decompress_to_buffer` restores `input`. Returns the appended framed bytes.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn framed_roundtrip(codec: CompressionCodec, input: &[u8]) -> Vec<u8> {
        let existing = [0xAB_u8; 3];
        let mut output = existing.to_vec();
        let written = codec
            .compress_to_vec(input, &mut output, &mut CompressionContext::default())
            .unwrap();
        assert_eq!(&output[..existing.len()], &existing);
        assert_eq!(written, output.len() - existing.len());
        let framed = output.split_off(existing.len());
        if !input.is_empty() {
            let decompressed = codec
                .decompress_to_buffer(&Buffer::from_vec(framed.clone()))
                .unwrap();
            assert_eq!(decompressed.as_slice(), input);
        }
        framed
    }

    /// Exercises the three framing outcomes of `compress_to_vec` for `codec`.
    #[cfg(any(feature = "lz4", feature = "zstd"))]
    fn check_framing(codec: CompressionCodec) {
        // Empty input writes nothing at all, not even a length prefix.
        assert!(framed_roundtrip(codec, &[]).is_empty());

        // Highly compressible data is stored compressed behind its uncompressed length.
        let repetitive = vec![7_u8; 4096];
        let framed = framed_roundtrip(codec, &repetitive);
        assert_eq!(&framed[..8], &4096_i64.to_le_bytes());
        assert!(framed.len() < repetitive.len());

        // Input too small to benefit from compression is stored raw behind a -1 prefix.
        let tiny = b"a";
        let framed = framed_roundtrip(codec, tiny);
        assert_eq!(&framed[..8], &(-1_i64).to_le_bytes());
        assert_eq!(&framed[8..], tiny);
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_compression() {
        let input_bytes = b"hello lz4";
        let codec = CompressionCodec::Lz4Frame;
        let mut compressed = Vec::new();
        codec
            .compress(input_bytes, &mut compressed, &mut CompressionContext::default())
            .unwrap();
        let restored = codec.decompress(&compressed, input_bytes.len()).unwrap();
        assert_eq!(input_bytes, restored.as_slice());
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_lz4_framing() {
        check_framing(CompressionCodec::Lz4Frame);
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_compression() {
        let input_bytes = b"hello zstd";
        let codec = CompressionCodec::Zstd;
        let mut compressed = Vec::new();
        codec
            .compress(input_bytes, &mut compressed, &mut CompressionContext::default())
            .unwrap();
        let restored = codec.decompress(&compressed, input_bytes.len()).unwrap();
        assert_eq!(input_bytes, restored.as_slice());
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_framing() {
        check_framing(CompressionCodec::Zstd);
    }

    #[test]
    fn test_codec_from_compression_type() {
        assert_eq!(
            CompressionCodec::try_from(CompressionType::LZ4_FRAME).unwrap(),
            CompressionCodec::Lz4Frame
        );
        assert_eq!(
            CompressionCodec::try_from(CompressionType::ZSTD).unwrap(),
            CompressionCodec::Zstd
        );
    }

    #[test]
    fn test_read_uncompressed_size() {
        let mut framed = 1234_i64.to_le_bytes().to_vec();
        framed.extend_from_slice(b"payload");
        assert_eq!(read_uncompressed_size(&framed), 1234);
        assert_eq!(
            read_uncompressed_size(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes()),
            LENGTH_NO_COMPRESSED_DATA
        );
    }
}