respect offset in utf8 and list casts (#335)
diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs
index 442a6d6..463da7c 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -1687,6 +1687,7 @@
};
let mut builder = ArrayData::builder(dtype)
+ .offset(array.offset())
.len(array.len())
.add_buffer(offset_buffer)
.add_buffer(str_values_buf);
@@ -1744,7 +1745,12 @@
_ => unreachable!(),
};
- let offsets = data.buffer::<OffsetSizeFrom>(0);
+ // Safety:
+ // The first buffer is the offsets and they are aligned to OffSetSizeFrom: (i64 or i32)
+ // Justification:
+ // The safe variant data.buffer::<OffsetSizeFrom> take the offset into account and we
+ // cannot create a list array with offsets starting at non zero.
+ let offsets = unsafe { data.buffers()[0].as_slice().align_to::<OffsetSizeFrom>() }.1;
let iter = offsets.iter().map(|idx| {
let idx: OffsetSizeTo = NumCast::from(*idx).unwrap();
@@ -1757,6 +1763,7 @@
// wrap up
let mut builder = ArrayData::builder(out_dtype)
+ .offset(array.offset())
.len(array.len())
.add_buffer(offset_buffer)
.add_child_data(value_data);
@@ -3841,4 +3848,30 @@
Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
]
}
+
+ #[test]
+ fn test_utf8_cast_offsets() {
+ // test if offset of the array is taken into account during cast
+ let str_array = StringArray::from(vec!["a", "b", "c"]);
+ let str_array = str_array.slice(1, 2);
+
+ let out = cast(&str_array, &DataType::LargeUtf8).unwrap();
+
+ let large_str_array = out.as_any().downcast_ref::<LargeStringArray>().unwrap();
+ let strs = large_str_array.into_iter().flatten().collect::<Vec<_>>();
+ assert_eq!(strs, &["b", "c"])
+ }
+
+ #[test]
+ fn test_list_cast_offsets() {
+ // test if offset of the array is taken into account during cast
+ let array1 = make_list_array().slice(1, 2);
+ let array2 = Arc::new(make_list_array()) as ArrayRef;
+
+ let dt = DataType::LargeList(Box::new(Field::new("item", DataType::Int32, true)));
+ let out1 = cast(&array1, &dt).unwrap();
+ let out2 = cast(&array2, &dt).unwrap();
+
+ assert_eq!(&out1, &out2.slice(1, 2))
+ }
}