blob: 0ee8e47bede06bc8b464add8c418ffff39cfd69a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::array::*;
use crate::compute::SlicesIterator;
use crate::error::{ArrowError, Result};
/// Zip two arrays by some boolean mask. Where the mask evaluates `true` values of `truthy`
/// are taken, where the mask evaluates `false` values of `falsy` are taken.
///
/// # Arguments
/// * `mask` - Boolean values used to determine from which array to take the values.
/// * `truthy` - Values of this array are taken if mask evaluates `true`
/// * `falsy` - Values of this array are taken if mask evaluates `false`
pub fn zip(
mask: &BooleanArray,
truthy: &dyn Array,
falsy: &dyn Array,
) -> Result<ArrayRef> {
if truthy.data_type() != falsy.data_type() {
return Err(ArrowError::InvalidArgumentError(
"arguments need to have the same data type".into(),
));
}
if truthy.len() != falsy.len() || falsy.len() != mask.len() {
return Err(ArrowError::InvalidArgumentError(
"all arrays should have the same length".into(),
));
}
let falsy = falsy.data();
let truthy = truthy.data();
let mut mutable = MutableArrayData::new(vec![&*truthy, &*falsy], false, truthy.len());
// the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
// fill with falsy values
// keep track of how much is filled
let mut filled = 0;
SlicesIterator::new(mask).for_each(|(start, end)| {
// the gap needs to be filled with falsy values
if start > filled {
mutable.extend(1, filled, start);
}
// fill with truthy values
mutable.extend(0, start, end);
filled = end;
});
// the remaining part is falsy
if filled < truthy.len() {
mutable.extend(1, filled, truthy.len());
}
let data = mutable.freeze();
Ok(make_array(data))
}
#[cfg(test)]
mod test {
    use super::*;

    /// Zipping two nullable `Int32Array`s must pick each slot (including nulls) from
    /// `truthy` where the mask is `true` and from `falsy` where it is `false`.
    #[test]
    fn test_zip_kernel() {
        let selector = BooleanArray::from(vec![true, true, false, false, true]);
        let when_true = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
        let when_false = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);

        let zipped = zip(&selector, &when_true, &when_false).unwrap();
        let zipped_ints = zipped.as_any().downcast_ref::<Int32Array>().unwrap();

        // slots 0, 1 and 4 come from `when_true`; slots 2 and 3 from `when_false`
        let expected = Int32Array::from(vec![Some(5), None, Some(6), Some(7), Some(1)]);
        assert_eq!(zipped_ints, &expected);
    }
}