blob: f21a1443db640edf3dc7a1db74e510062c1a5a03 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! The repartition operator maps N input partitions to M output partitions based on a
//! partitioning scheme.
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, collections::HashMap, vec};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{ExecutionPlan, Partitioning};
use arrow::record_batch::RecordBatch;
use arrow::{array::Array, error::Result as ArrowResult};
use arrow::{compute::take, datatypes::SchemaRef};
use tokio_stream::wrappers::UnboundedReceiverStream;
use super::{hash_join::create_hashes, RecordBatchStream, SendableRecordBatchStream};
use async_trait::async_trait;
use futures::stream::Stream;
use futures::StreamExt;
use tokio::sync::{
mpsc::{UnboundedReceiver, UnboundedSender},
Mutex,
};
use tokio::task::JoinHandle;
/// Item type carried on the input→output channels: a batch (or Arrow error)
/// wrapped in `Some`, or `None` to signal that the sending *input* partition
/// has no more data for this output partition.
type MaybeBatch = Option<ArrowResult<RecordBatch>>;
/// The repartition operator maps N input partitions to M output partitions based on a
/// partitioning scheme. No guarantees are made about the order of the resulting partitions.
#[derive(Debug)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Partitioning scheme to use
    partitioning: Partitioning,
    /// Channels for sending batches from input partitions to output partitions.
    /// Key is the *output* partition number. The sender half is cloned into every
    /// per-input-partition task in `execute`, and the receiver half is removed
    /// from the map and handed to the output stream for that partition.
    /// Guarded by an async `Mutex` since `execute` is awaited per partition.
    channels: Arc<
        Mutex<
            HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
        >,
    >,
}
/// Accessors for the operator's configuration.
impl RepartitionExec {
    /// The input execution plan being repartitioned
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
    /// The partitioning scheme used to produce the output partitions
    pub fn partitioning(&self) -> &Partitioning {
        &self.partitioning
    }
}
#[async_trait]
impl ExecutionPlan for RepartitionExec {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Get the schema for this execution plan. Repartitioning only redistributes
    /// rows, so the schema is the input's schema unchanged.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
    /// The single child of this node: the plan whose output is redistributed
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![self.input.clone()]
    }
    /// Rebuild this node over a new (single) child, keeping the same partitioning
    fn with_new_children(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(RepartitionExec::try_new(
                children[0].clone(),
                self.partitioning.clone(),
            )?)),
            _ => Err(DataFusionError::Internal(
                "RepartitionExec wrong number of children".to_string(),
            )),
        }
    }
    /// The output partitioning is exactly the scheme this node was built with
    fn output_partitioning(&self) -> Partitioning {
        self.partitioning.clone()
    }
    /// Execute one *output* partition, returning a stream over its batches.
    ///
    /// The first invocation lazily sets up all shared state: one unbounded
    /// channel per output partition, and one spawned tokio task per *input*
    /// partition that drains that input's stream and routes each batch (or,
    /// for hash partitioning, each row) to the appropriate output channel.
    /// Later invocations find `channels` non-empty and only take their receiver.
    ///
    /// NOTE(review): the receiver is `remove`d from the map below, so calling
    /// `execute` twice with the same `partition` value would panic on `unwrap`.
    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
        // lock mutexes
        let mut channels = self.channels.lock().await;
        let num_input_partitions = self.input.output_partitioning().partition_count();
        let num_output_partitions = self.partitioning.partition_count();
        // if this is the first partition to be invoked then we need to set up initial state
        if channels.is_empty() {
            // create one channel per *output* partition
            for partition in 0..num_output_partitions {
                // Note that this operator uses unbounded channels to avoid deadlocks because
                // the output partitions can be read in any order and this could cause input
                // partitions to be blocked when sending data to output UnboundedReceivers that are not
                // being read yet. This may cause high memory usage if the next operator is
                // reading output partitions in order rather than concurrently. One workaround
                // for this would be to add spill-to-disk capabilities.
                let (sender, receiver) = tokio::sync::mpsc::unbounded_channel::<
                    Option<ArrowResult<RecordBatch>>,
                >();
                channels.insert(partition, (sender, receiver));
            }
            // a single hash seed shared by all input tasks so that equal keys
            // hash to the same output partition regardless of which input
            // partition they came from
            let random = ahash::RandomState::new();
            // launch one async task per *input* partition
            for i in 0..num_input_partitions {
                let random_state = random.clone();
                let input = self.input.clone();
                // each task owns a clone of every output partition's sender
                let mut txs: HashMap<_, _> = channels
                    .iter()
                    .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
                    .collect();
                let partitioning = self.partitioning.clone();
                let _: JoinHandle<Result<()>> = tokio::spawn(async move {
                    let mut stream = input.execute(i).await?;
                    // batch counter; only consulted for round-robin routing
                    let mut counter = 0;
                    while let Some(result) = stream.next().await {
                        match &partitioning {
                            Partitioning::RoundRobinBatch(_) => {
                                // forward whole batches to output partitions in rotation
                                let output_partition = counter % num_output_partitions;
                                let tx = txs.get_mut(&output_partition).unwrap();
                                tx.send(Some(result)).map_err(|e| {
                                    DataFusionError::Execution(e.to_string())
                                })?;
                            }
                            Partitioning::Hash(exprs, _) => {
                                let input_batch = result?;
                                // evaluate each partitioning expression against the batch
                                let arrays = exprs
                                    .iter()
                                    .map(|expr| {
                                        Ok(expr
                                            .evaluate(&input_batch)?
                                            .into_array(input_batch.num_rows()))
                                    })
                                    .collect::<Result<Vec<_>>>()?;
                                // Hash arrays and compute buckets based on number of partitions
                                let hashes_buf = &mut vec![0; arrays[0].len()];
                                let hashes =
                                    create_hashes(&arrays, &random_state, hashes_buf)?;
                                // group row indices by destination output partition
                                let mut indices = vec![vec![]; num_output_partitions];
                                for (index, hash) in hashes.iter().enumerate() {
                                    indices
                                        [(*hash % num_output_partitions as u64) as usize]
                                        .push(index as u64)
                                }
                                // emit one (possibly empty) batch per output partition
                                for (num_output_partition, partition_indices) in
                                    indices.into_iter().enumerate()
                                {
                                    let indices = partition_indices.into();
                                    // Produce batches based on indices
                                    let columns = input_batch
                                        .columns()
                                        .iter()
                                        .map(|c| {
                                            take(c.as_ref(), &indices, None).map_err(
                                                |e| {
                                                    DataFusionError::Execution(
                                                        e.to_string(),
                                                    )
                                                },
                                            )
                                        })
                                        .collect::<Result<Vec<Arc<dyn Array>>>>()?;
                                    // send the (still fallible) batch construction
                                    // result; the receiver propagates any error
                                    let output_batch = RecordBatch::try_new(
                                        input_batch.schema(),
                                        columns,
                                    );
                                    let tx = txs.get_mut(&num_output_partition).unwrap();
                                    tx.send(Some(output_batch)).map_err(|e| {
                                        DataFusionError::Execution(e.to_string())
                                    })?;
                                }
                            }
                            other => {
                                // this should be unreachable as long as the validation logic
                                // in the constructor is kept up-to-date
                                // NOTE(review): `try_new` currently performs no such
                                // validation — confirm or add it
                                return Err(DataFusionError::NotImplemented(format!(
                                    "Unsupported repartitioning scheme {:?}",
                                    other
                                )));
                            }
                        }
                        counter += 1;
                    }
                    // notify each output partition that this input partition has no more data
                    for (_, tx) in txs {
                        tx.send(None)
                            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
                    }
                    Ok(())
                });
            }
        }
        // now return stream for the specified *output* partition which will
        // read from the channel
        Ok(Box::pin(RepartitionStream {
            num_input_partitions,
            num_input_partitions_processed: 0,
            schema: self.input.schema(),
            input: UnboundedReceiverStream::new(channels.remove(&partition).unwrap().1),
        }))
    }
}
impl RepartitionExec {
    /// Create a new `RepartitionExec` that redistributes the batches produced
    /// by `input` according to `partitioning`. The inter-partition channels
    /// are created lazily on first `execute`, so construction is cheap.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let channels = Arc::new(Mutex::new(HashMap::new()));
        Ok(Self {
            input,
            partitioning,
            channels,
        })
    }
}
/// Stream over the batches routed to one output partition. It yields batches
/// until *every* input partition has sent its `None` completion marker.
struct RepartitionStream {
    /// Number of input partitions that will be sending batches to this output channel
    num_input_partitions: usize,
    /// Number of input partitions that have finished sending batches to this output channel
    num_input_partitions_processed: usize,
    /// Schema of the produced batches (the input plan's schema)
    schema: SchemaRef,
    /// channel containing the repartitioned batches; a `None` item marks the
    /// end of one input partition's data
    input: UnboundedReceiverStream<Option<ArrowResult<RecordBatch>>>,
}
impl Stream for RepartitionStream {
    type Item = ArrowResult<RecordBatch>;

    /// Poll the underlying channel, translating the `Option`-wrapped protocol
    /// (`Some(batch)` = data, `None` = one input partition finished) into a
    /// plain batch stream that ends once all input partitions are done.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            return match self.input.poll_next_unpin(cx) {
                // a repartitioned batch (or error) destined for this partition
                Poll::Ready(Some(Some(item))) => Poll::Ready(Some(item)),
                // completion marker from one input partition
                Poll::Ready(Some(None)) => {
                    self.num_input_partitions_processed += 1;
                    if self.num_input_partitions_processed == self.num_input_partitions {
                        // every input partition has finished sending batches
                        Poll::Ready(None)
                    } else {
                        // other input partitions may still have data; poll again
                        continue;
                    }
                }
                // channel closed: no more data will arrive
                Poll::Ready(None) => Poll::Ready(None),
                Poll::Pending => Poll::Pending,
            };
        }
    }
}
impl RecordBatchStream for RepartitionStream {
/// Get the schema
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[cfg(test)]
mod tests {
    //! Round-trip tests: build in-memory input partitions, repartition them,
    //! and check the batch counts per output partition.
    use super::*;
    use crate::physical_plan::memory::MemoryExec;
    use arrow::array::UInt32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    #[tokio::test]
    async fn one_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition];
        // repartition from 1 input to 4 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
        // 50 batches round-robined over 4 partitions: first two get the extras
        assert_eq!(4, output_partitions.len());
        assert_eq!(13, output_partitions[0].len());
        assert_eq!(13, output_partitions[1].len());
        assert_eq!(12, output_partitions[2].len());
        assert_eq!(12, output_partitions[3].len());
        Ok(())
    }
    #[tokio::test]
    async fn many_to_one_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
        // repartition from 3 input to 1 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
        // all 3 * 50 batches funnel into the single output partition
        assert_eq!(1, output_partitions.len());
        assert_eq!(150, output_partitions[0].len());
        Ok(())
    }
    #[tokio::test]
    async fn many_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
        // repartition from 3 input to 5 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
        // each of the 3 inputs round-robins its 50 batches over 5 outputs: 10 each
        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());
        Ok(())
    }
    #[tokio::test]
    async fn many_to_many_hash_partition() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
        let output_partitions = repartition(
            &schema,
            partitions,
            Partitioning::Hash(
                vec![Arc::new(crate::physical_plan::expressions::Column::new(
                    &"c0",
                ))],
                8,
            ),
        )
        .await?;
        // hash partitioning emits one (possibly empty) batch per output partition
        // for every input batch, so only the partition/batch counts are checked
        let total_rows: usize = output_partitions.iter().map(|x| x.len()).sum();
        assert_eq!(8, output_partitions.len());
        assert_eq!(total_rows, 8 * 50 * 3);
        Ok(())
    }
    /// Single-column UInt32 schema shared by all tests
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }
    /// Build `n` identical 8-row batches for one input partition
    fn create_vec_batches(schema: &Arc<Schema>, n: usize) -> Vec<RecordBatch> {
        let batch = create_batch(schema);
        let mut vec = Vec::with_capacity(n);
        for _ in 0..n {
            vec.push(batch.clone());
        }
        vec
    }
    /// One 8-row batch with values 1..=8 in column `c0`
    fn create_batch(schema: &Arc<Schema>) -> RecordBatch {
        RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
        )
        .unwrap()
    }
    /// Run `input_partitions` through a `RepartitionExec` with the given
    /// scheme and collect every output partition's batches sequentially
    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        // create physical plan
        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
        // execute and collect results
        let mut output_partitions = vec![];
        for i in 0..exec.partitioning.partition_count() {
            // execute this *output* partition and collect all batches
            let mut stream = exec.execute(i).await?;
            let mut batches = vec![];
            while let Some(result) = stream.next().await {
                batches.push(result?);
            }
            output_partitions.push(batches);
        }
        Ok(output_partitions)
    }
    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
        // same as many_to_many_round_robin, but driven from inside a spawned
        // task to exercise execution off the test's root task
        let join_handle: JoinHandle<Result<Vec<Vec<RecordBatch>>>> =
            tokio::spawn(async move {
                // define input partitions
                let schema = test_schema();
                let partition = create_vec_batches(&schema, 50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];
                // repartition from 3 input to 5 output
                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });
        let output_partitions = join_handle
            .await
            .map_err(|e| DataFusionError::Internal(e.to_string()))??;
        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());
        Ok(())
    }
}