ARROW-11806: [Rust][DataFusion] Optimize join / inner join creation of indices
This PR implements two optimizations
* Change the way we create an array of indices for an inner join to avoid generating a null bitmap. Arrow currently does not make this ergonomic without resorting to an iterator (which would be hard to do here). This yields roughly a 3% improvement.
* Allow reusing allocations in `create_hashes` when possible. This is around 2% faster.
In total this gives a small (~5%) speedup on query 5:
This PR:
```
Query 5 iteration 0 took 169.3 ms
Query 5 iteration 1 took 156.0 ms
Query 5 iteration 2 took 157.5 ms
Query 5 iteration 3 took 158.0 ms
Query 5 iteration 4 took 157.3 ms
Query 5 iteration 5 took 163.4 ms
Query 5 iteration 6 took 167.6 ms
Query 5 iteration 7 took 171.5 ms
Query 5 iteration 8 took 167.4 ms
Query 5 iteration 9 took 164.5 ms
Query 5 avg time: 163.26 ms
```
Master:
```
Query 5 iteration 0 took 177.6 ms
Query 5 iteration 1 took 169.6 ms
Query 5 iteration 2 took 171.8 ms
Query 5 iteration 3 took 175.1 ms
Query 5 iteration 4 took 167.2 ms
Query 5 iteration 5 took 171.1 ms
Query 5 iteration 6 took 174.2 ms
Query 5 iteration 7 took 178.1 ms
Query 5 iteration 8 took 167.9 ms
Query 5 iteration 9 took 172.0 ms
Query 5 avg time: 172.46 ms
```
Closes #9595 from Dandandan/opt_hash_join
Authored-by: Heres, Daniel <danielheres@gmail.com>
Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs
index 25630a9..7ca769a 100644
--- a/rust/datafusion/src/physical_plan/hash_join.rs
+++ b/rust/datafusion/src/physical_plan/hash_join.rs
@@ -23,11 +23,12 @@
use arrow::{
array::{
- ArrayRef, BooleanArray, LargeStringArray, TimestampMicrosecondArray,
- TimestampNanosecondArray, UInt32Builder, UInt64Builder,
+ ArrayData, ArrayRef, BooleanArray, LargeStringArray, PrimitiveArray,
+ TimestampMicrosecondArray, TimestampNanosecondArray, UInt32BufferBuilder,
+ UInt32Builder, UInt64BufferBuilder, UInt64Builder,
},
compute,
- datatypes::TimeUnit,
+ datatypes::{TimeUnit, UInt32Type, UInt64Type},
};
use std::time::Instant;
use std::{any::Any, collections::HashSet};
@@ -237,19 +238,26 @@
// This operation performs 2 steps at once:
// 1. creates a [JoinHashMap] of all batches from the stream
// 2. stores the batches in a vector.
- let initial =
- (JoinHashMap::with_hasher(IdHashBuilder {}), Vec::new(), 0);
- let (hashmap, batches, num_rows) = stream
+ let initial = (
+ JoinHashMap::with_hasher(IdHashBuilder {}),
+ Vec::new(),
+ 0,
+ Vec::new(),
+ );
+ let (hashmap, batches, num_rows, _) = stream
.try_fold(initial, |mut acc, batch| async {
let hash = &mut acc.0;
let values = &mut acc.1;
let offset = acc.2;
+ acc.3.clear();
+ acc.3.resize(batch.num_rows(), 0);
update_hash(
&on_left,
&batch,
hash,
offset,
&self.random_state,
+ &mut acc.3,
)
.unwrap();
acc.2 += batch.num_rows();
@@ -311,6 +319,7 @@
hash: &mut JoinHashMap,
offset: usize,
random_state: &RandomState,
+ hashes_buffer: &mut Vec<u64>,
) -> Result<()> {
// evaluate the keys
let keys_values = on
@@ -319,7 +328,7 @@
.collect::<Result<Vec<_>>>()?;
// update the hash map
- let hash_values = create_hashes(&keys_values, &random_state)?;
+ let hash_values = create_hashes(&keys_values, &random_state, hashes_buffer)?;
// insert hashes to key of the hashmap
for (row, hash_value) in hash_values.iter().enumerate() {
@@ -476,15 +485,16 @@
.into_array(left_data.1.num_rows()))
})
.collect::<Result<Vec<_>>>()?;
-
- let hash_values = create_hashes(&keys_values, &random_state)?;
+ let hashes_buffer = &mut vec![0; keys_values[0].len()];
+ let hash_values = create_hashes(&keys_values, &random_state, hashes_buffer)?;
let left = &left_data.0;
- let mut left_indices = UInt64Builder::new(0);
- let mut right_indices = UInt32Builder::new(0);
-
match join_type {
JoinType::Inner => {
+ // Using a buffer builder to avoid slower normal builder
+ let mut left_indices = UInt64BufferBuilder::new(0);
+ let mut right_indices = UInt32BufferBuilder::new(0);
+
// Visit all of the right rows
for (row, hash_value) in hash_values.iter().enumerate() {
// Get the hash and find it in the build index
@@ -496,15 +506,30 @@
for &i in indices {
// Check hash collisions
if equal_rows(i as usize, row, &left_join_values, &keys_values)? {
- left_indices.append_value(i)?;
- right_indices.append_value(row as u32)?;
+ left_indices.append(i);
+ right_indices.append(row as u32);
}
}
}
}
- Ok((left_indices.finish(), right_indices.finish()))
+ let left = ArrayData::builder(DataType::UInt64)
+ .len(left_indices.len())
+ .add_buffer(left_indices.finish())
+ .build();
+ let right = ArrayData::builder(DataType::UInt32)
+ .len(right_indices.len())
+ .add_buffer(right_indices.finish())
+ .build();
+
+ Ok((
+ PrimitiveArray::<UInt64Type>::from(left),
+ PrimitiveArray::<UInt32Type>::from(right),
+ ))
}
JoinType::Left => {
+ let mut left_indices = UInt64Builder::new(0);
+ let mut right_indices = UInt32Builder::new(0);
+
// Keep track of which item is visited in the build input
// TODO: this can be stored more efficiently with a marker
// https://issues.apache.org/jira/browse/ARROW-11116
@@ -534,10 +559,12 @@
}
}
}
-
Ok((left_indices.finish(), right_indices.finish()))
}
JoinType::Right => {
+ let mut left_indices = UInt64Builder::new(0);
+ let mut right_indices = UInt32Builder::new(0);
+
for (row, hash_value) in hash_values.iter().enumerate() {
match left.get(hash_value) {
Some(indices) => {
@@ -699,50 +726,60 @@
}
/// Creates hash values for every element in the row based on the values in the columns
-pub fn create_hashes(
+pub fn create_hashes<'a>(
arrays: &[ArrayRef],
random_state: &RandomState,
-) -> Result<Vec<u64>> {
- let rows = arrays[0].len();
- let mut hashes = vec![0; rows];
-
+ hashes_buffer: &'a mut Vec<u64>,
+) -> Result<&'a mut Vec<u64>> {
for col in arrays {
match col.data_type() {
DataType::UInt8 => {
- hash_array!(UInt8Array, col, u8, hashes, random_state);
+ hash_array!(UInt8Array, col, u8, hashes_buffer, random_state);
}
DataType::UInt16 => {
- hash_array!(UInt16Array, col, u16, hashes, random_state);
+ hash_array!(UInt16Array, col, u16, hashes_buffer, random_state);
}
DataType::UInt32 => {
- hash_array!(UInt32Array, col, u32, hashes, random_state);
+ hash_array!(UInt32Array, col, u32, hashes_buffer, random_state);
}
DataType::UInt64 => {
- hash_array!(UInt64Array, col, u64, hashes, random_state);
+ hash_array!(UInt64Array, col, u64, hashes_buffer, random_state);
}
DataType::Int8 => {
- hash_array!(Int8Array, col, i8, hashes, random_state);
+ hash_array!(Int8Array, col, i8, hashes_buffer, random_state);
}
DataType::Int16 => {
- hash_array!(Int16Array, col, i16, hashes, random_state);
+ hash_array!(Int16Array, col, i16, hashes_buffer, random_state);
}
DataType::Int32 => {
- hash_array!(Int32Array, col, i32, hashes, random_state);
+ hash_array!(Int32Array, col, i32, hashes_buffer, random_state);
}
DataType::Int64 => {
- hash_array!(Int64Array, col, i64, hashes, random_state);
+ hash_array!(Int64Array, col, i64, hashes_buffer, random_state);
}
DataType::Timestamp(TimeUnit::Microsecond, None) => {
- hash_array!(TimestampMicrosecondArray, col, i64, hashes, random_state);
+ hash_array!(
+ TimestampMicrosecondArray,
+ col,
+ i64,
+ hashes_buffer,
+ random_state
+ );
}
DataType::Timestamp(TimeUnit::Nanosecond, None) => {
- hash_array!(TimestampNanosecondArray, col, i64, hashes, random_state);
+ hash_array!(
+ TimestampNanosecondArray,
+ col,
+ i64,
+ hashes_buffer,
+ random_state
+ );
}
DataType::Boolean => {
- hash_array!(BooleanArray, col, u8, hashes, random_state);
+ hash_array!(BooleanArray, col, u8, hashes_buffer, random_state);
}
DataType::Utf8 => {
- hash_array!(StringArray, col, str, hashes, random_state);
+ hash_array!(StringArray, col, str, hashes_buffer, random_state);
}
_ => {
// This is internal because we should have caught this before.
@@ -752,7 +789,7 @@
}
}
}
- Ok(hashes)
+ Ok(hashes_buffer)
}
impl Stream for HashJoinStream {
@@ -1136,8 +1173,9 @@
);
let random_state = RandomState::new();
-
- let hashes = create_hashes(&[left.columns()[0].clone()], &random_state)?;
+ let hashes_buff = &mut vec![0; left.num_rows()];
+ let hashes =
+ create_hashes(&[left.columns()[0].clone()], &random_state, hashes_buff)?;
// Create hash collisions
hashmap_left.insert(hashes[0], vec![0, 1]);
diff --git a/rust/datafusion/src/physical_plan/repartition.rs b/rust/datafusion/src/physical_plan/repartition.rs
index f62ae23..f21a144 100644
--- a/rust/datafusion/src/physical_plan/repartition.rs
+++ b/rust/datafusion/src/physical_plan/repartition.rs
@@ -163,7 +163,9 @@
})
.collect::<Result<Vec<_>>>()?;
// Hash arrays and compute buckets based on number of partitions
- let hashes = create_hashes(&arrays, &random_state)?;
+ let hashes_buf = &mut vec![0; arrays[0].len()];
+ let hashes =
+ create_hashes(&arrays, &random_state, hashes_buf)?;
let mut indices = vec![vec![]; num_output_partitions];
for (index, hash) in hashes.iter().enumerate() {
indices