blob: 619e1fce393ae486eb8de048b8a548fdc2eabb2c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ---------------------------------------------------------------------
# Implement Internal ExecPlan bindings
# cython: profile=False
# distutils: language = c++
# cython: language_level = 3
from pyarrow.lib import Table
from pyarrow.compute import Expression, field
try:
from pyarrow._acero import ( # noqa
Declaration,
ExecNodeOptions,
TableSourceNodeOptions,
FilterNodeOptions,
ProjectNodeOptions,
AggregateNodeOptions,
OrderByNodeOptions,
HashJoinNodeOptions,
AsofJoinNodeOptions,
)
except ImportError as exc:
raise ImportError(
f"The pyarrow installation is not built with support for 'acero' ({str(exc)})"
) from None
# The dataset module is optional. When it cannot be imported, ``ds`` is bound
# to a stub exposing empty ``Dataset`` and ``InMemoryDataset`` classes so that
# the ``isinstance(..., ds.Dataset)`` and ``output_type == ds.InMemoryDataset``
# checks further down in this module can still be evaluated.
try:
    import pyarrow.dataset as ds
    from pyarrow._dataset import ScanNodeOptions
except ImportError:
    # NOTE: ``ScanNodeOptions`` stays undefined on this path; within this
    # module it is only referenced from ``_dataset_to_decl``.
    class DatasetModuleStub:
        # Placeholder standing in for ``pyarrow.dataset.Dataset``.
        class Dataset:
            pass

        # Placeholder standing in for ``pyarrow.dataset.InMemoryDataset``.
        class InMemoryDataset:
            pass
    ds = DatasetModuleStub
def _dataset_to_decl(dataset, use_threads=True):
    """
    Build the declaration that reads a dataset as a plan source.

    The returned declaration is a "scan" node followed by a "project" node
    and, when the dataset carries a filter in its scan options, a "filter"
    node.

    Parameters
    ----------
    dataset : Dataset
        The dataset to read from.
    use_threads : bool, default True
        Forwarded to ``ScanNodeOptions``.

    Returns
    -------
    Declaration
    """
    scan_node = Declaration(
        "scan", ScanNodeOptions(dataset, use_threads=use_threads)
    )

    # Project onto the dataset's own schema so the special dataset columns
    # ("__fragment_index", "__batch_index", "__last_in_fragment",
    # "__filename") do not leak into the output.
    schema_fields = [field(name) for name in dataset.schema.names]
    project_node = Declaration("project", ProjectNodeOptions(schema_fields))
    plan = Declaration.from_sequence([scan_node, project_node])

    filter_expr = dataset._scan_options.get("filter")
    if filter_expr is None:
        return plan

    # Filters applied in CScanNodeOptions are "best effort" for the scan node
    # itself, so an explicit filter node is always appended to really apply
    # them.
    filter_node = Declaration("filter", FilterNodeOptions(filter_expr))
    return Declaration.from_sequence([plan, filter_node])
def _perform_join(join_type, left_operand, left_keys,
                  right_operand, right_keys,
                  left_suffix=None, right_suffix=None,
                  use_threads=True, coalesce_keys=False,
                  output_type=Table):
    """
    Hash-join two tables or datasets.

    Builds a "hashjoin" declaration over both operands, optionally followed
    by a projection that coalesces the key columns, and materializes it.

    Parameters
    ----------
    join_type : str
        The join type forwarded to ``HashJoinNodeOptions``. The values
        handled explicitly here are "left semi", "left anti", "right semi",
        "right anti", "inner", "left outer", "right outer" and "full outer".
    left_operand : Table or Dataset
        The left side of the join.
    left_keys : str or list[str]
        The key (or keys) of the left side to join on.
    right_operand : Table or Dataset
        The right side of the join.
    right_keys : str or list[str]
        The key (or keys) of the right side to join on. Keys are paired
        with ``left_keys`` by position.
    left_suffix : str, default None
        Suffix for left column names, used to disambiguate columns whose
        names collide between the two operands.
    right_suffix : str, default None
        Suffix for right column names, used to disambiguate columns whose
        names collide between the two operands.
    use_threads : bool, default True
        Whether to use multithreading or not.
    coalesce_keys : bool, default False
        If True, the duplicated key columns are omitted from one of the
        sides in the result. For "full outer" joins the left and right keys
        are merged with "coalesce" in an extra projection.
    output_type : Table or InMemoryDataset
        The type in which the result is returned.

    Returns
    -------
    result_table : Table or InMemoryDataset

    Raises
    ------
    TypeError
        If an operand is neither a Table nor a Dataset, or if
        ``output_type`` is not supported.
    """
    for operand in (left_operand, right_operand):
        if not isinstance(operand, (Table, ds.Dataset)):
            raise TypeError(f"Expected Table or Dataset, got {type(operand)}")

    # Accept a single key as well as a sequence, and remember the position
    # of every key so left and right keys can be paired later on.
    if not isinstance(left_keys, (tuple, list)):
        left_keys = [left_keys]
    left_key_positions = {key: pos for pos, key in enumerate(left_keys)}
    if not isinstance(right_keys, (list, tuple)):
        right_keys = [right_keys]
    right_key_positions = {key: pos for pos, key in enumerate(right_keys)}

    # Start from every column of both sides, then trim depending on the
    # join type.
    left_columns = left_operand.schema.names
    right_columns = right_operand.schema.names
    if join_type in ("left semi", "left anti"):
        right_columns = []
    elif join_type in ("right semi", "right anti"):
        left_columns = []
    elif join_type in ("inner", "left outer"):
        right_columns = [
            name for name in right_columns if name not in right_key_positions
        ]
    elif join_type == "right outer":
        left_columns = [
            name for name in left_columns if name not in left_key_positions
        ]

    # Position of each key column within the (possibly trimmed) column lists.
    left_column_keys_indices = {
        name: pos for pos, name in enumerate(left_columns)
        if name in left_keys
    }
    right_column_keys_indices = {
        name: pos for pos, name in enumerate(right_columns)
        if name in right_keys
    }

    def _as_source(operand):
        # Datasets are read through a scan plan, tables through table_source.
        if isinstance(operand, ds.Dataset):
            return _dataset_to_decl(operand, use_threads=use_threads)
        return Declaration("table_source", TableSourceNodeOptions(operand))

    left_source = _as_source(left_operand)
    right_source = _as_source(right_operand)

    suffix_options = {
        "output_suffix_for_left": left_suffix or "",
        "output_suffix_for_right": right_suffix or "",
    }
    if coalesce_keys:
        # Explicit output columns let the join drop the duplicated keys.
        join_opts = HashJoinNodeOptions(
            join_type, left_keys, right_keys, left_columns, right_columns,
            **suffix_options
        )
    else:
        join_opts = HashJoinNodeOptions(
            join_type, left_keys, right_keys, **suffix_options
        )
    decl = Declaration(
        "hashjoin", options=join_opts, inputs=[left_source, right_source]
    )

    if coalesce_keys and join_type == "full outer":
        # A full outer join outputs all columns of both sides; the keys are
        # coalesced and the duplicates dropped in a subsequent projection.
        left_names = set(left_columns)
        right_names = set(right_columns)
        # Index of the first right-side column in the join output.
        right_offset = len(left_columns)
        output_names = []
        output_exprs = []

        for pos, name in enumerate(left_columns):
            if name in left_column_keys_indices:
                # Emit each key once, merging the left key with the right
                # key found at the same position in the provided key lists.
                paired_right_key = right_keys[left_key_positions[name]]
                right_pos = (
                    right_offset + right_column_keys_indices[paired_right_key]
                )
                output_names.append(name)
                output_exprs.append(
                    Expression._call("coalesce", [
                        Expression._field(pos), Expression._field(right_pos)
                    ])
                )
                continue
            # Re-apply the suffix the join produced; the projection would
            # otherwise lose it.
            if left_suffix and name in right_names:
                name += left_suffix
            output_names.append(name)
            output_exprs.append(Expression._field(pos))

        for pos, name in enumerate(right_columns, start=right_offset):
            if name in right_column_keys_indices:
                # Right keys were already merged into the left ones.
                continue
            if right_suffix and name in left_names:
                name += right_suffix
            output_names.append(name)
            output_exprs.append(Expression._field(pos))

        projection = Declaration(
            "project", ProjectNodeOptions(output_exprs, output_names)
        )
        decl = Declaration.from_sequence([decl, projection])

    result_table = decl.to_table(use_threads=use_threads)

    if output_type == Table:
        return result_table
    if output_type == ds.InMemoryDataset:
        return ds.InMemoryDataset(result_table)
    raise TypeError("Unsupported output type")
def _perform_join_asof(left_operand, left_on, left_by,
                       right_operand, right_on, right_by,
                       tolerance, use_threads=True,
                       output_type=Table):
    """
    Perform asof join of two tables or datasets.

    The result will be an output table with the result of the join operation

    Parameters
    ----------
    left_operand : Table or Dataset
        The left operand for the join operation.
    left_on : str
        The column of the left operand used as the "on" key of the join.
    left_by : str or list[str]
        The column (or columns) of the left operand used as the "by" keys
        of the join. A tuple is accepted as well.
    right_operand : Table or Dataset
        The right operand for the join operation.
    right_on : str
        The column of the right operand used as the "on" key of the join.
    right_by : str or list[str]
        The column (or columns) of the right operand used as the "by" keys
        of the join. A tuple is accepted as well.
    tolerance : int
        The tolerance to use for the asof join. The tolerance is interpreted in
        the same units as the "on" key.
    use_threads : bool, default True
        Whether to use multithreading or not.
    output_type: Table or InMemoryDataset
        The output type for the exec plan result.

    Returns
    -------
    result_table : Table or InMemoryDataset

    Raises
    ------
    TypeError
        If an operand is neither a Table nor a Dataset, or if
        ``output_type`` is not supported.
    ValueError
        If a non-key column of the right operand has the same name as a
        column of the left operand.
    """
    if not isinstance(left_operand, (Table, ds.Dataset)):
        raise TypeError(f"Expected Table or Dataset, got {type(left_operand)}")
    if not isinstance(right_operand, (Table, ds.Dataset)):
        raise TypeError(f"Expected Table or Dataset, got {type(right_operand)}")

    # Normalize the "by" keys to lists. Tuples are converted too: they used
    # to be kept as-is, which made the ``[right_on] + right_by`` concatenation
    # below fail with a TypeError.
    if isinstance(left_by, (tuple, list)):
        left_by = list(left_by)
    else:
        left_by = [left_by]
    if isinstance(right_by, (tuple, list)):
        right_by = list(right_by)
    else:
        right_by = [right_by]

    # AsofJoin does not return on or by columns for right_operand.
    right_key_names = [right_on] + right_by
    right_columns = [
        col for col in right_operand.schema.names
        if col not in right_key_names
    ]
    columns_collisions = set(left_operand.schema.names) & set(right_columns)
    if columns_collisions:
        raise ValueError(
            "Columns {} present in both tables. AsofJoin does not support "
            "column collisions.".format(columns_collisions),
        )

    # Add the join node to the execplan
    if isinstance(left_operand, ds.Dataset):
        left_source = _dataset_to_decl(left_operand, use_threads=use_threads)
    else:
        left_source = Declaration(
            "table_source", TableSourceNodeOptions(left_operand),
        )
    if isinstance(right_operand, ds.Dataset):
        right_source = _dataset_to_decl(right_operand, use_threads=use_threads)
    else:
        right_source = Declaration(
            "table_source", TableSourceNodeOptions(right_operand)
        )

    join_opts = AsofJoinNodeOptions(
        left_on, left_by, right_on, right_by, tolerance
    )
    decl = Declaration(
        "asofjoin", options=join_opts, inputs=[left_source, right_source]
    )
    result_table = decl.to_table(use_threads=use_threads)

    if output_type == Table:
        return result_table
    elif output_type == ds.InMemoryDataset:
        return ds.InMemoryDataset(result_table)
    else:
        raise TypeError("Unsupported output type")
def _filter_table(table, expression):
    """Filter rows of a table based on the provided expression.

    The result will be an output table with only the rows matching
    the provided expression.

    Parameters
    ----------
    table : Table or Dataset
        Table or Dataset that should be filtered. A Dataset is read through
        the same scan plan the other helpers of this module use; anything
        else is handed to a "table_source" node.
    expression : Expression
        The expression on which rows should be filtered.

    Returns
    -------
    Table
    """
    # The documented contract accepts datasets, but a "table_source" node was
    # built unconditionally. Dispatch on the input kind the same way
    # ``_sort_source`` does.
    if isinstance(table, ds.Dataset):
        data_source = _dataset_to_decl(table, use_threads=True)
    else:
        data_source = Declaration(
            "table_source", options=TableSourceNodeOptions(table)
        )
    decl = Declaration.from_sequence([
        data_source,
        Declaration("filter", options=FilterNodeOptions(expression))
    ])
    return decl.to_table(use_threads=True)
def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs):
    """
    Sort a table or dataset with an "order_by" node.

    Parameters
    ----------
    table_or_dataset : Table or Dataset
        The data to sort. A Dataset is read through a scan plan, anything
        else is handed to a "table_source" node.
    sort_keys
        The sort keys, forwarded to ``OrderByNodeOptions``.
    output_type : Table or InMemoryDataset
        The type in which the result is returned.
    **kwargs
        Additional keyword arguments forwarded to ``OrderByNodeOptions``.

    Returns
    -------
    Table or InMemoryDataset

    Raises
    ------
    TypeError
        If ``output_type`` is not supported.
    """
    if not isinstance(table_or_dataset, ds.Dataset):
        source_node = Declaration(
            "table_source", TableSourceNodeOptions(table_or_dataset)
        )
    else:
        source_node = _dataset_to_decl(table_or_dataset, use_threads=True)

    sort_node = Declaration(
        "order_by", OrderByNodeOptions(sort_keys, **kwargs)
    )
    plan = Declaration.from_sequence([source_node, sort_node])
    sorted_table = plan.to_table(use_threads=True)

    if output_type == Table:
        return sorted_table
    if output_type == ds.InMemoryDataset:
        return ds.InMemoryDataset(sorted_table)
    raise TypeError("Unsupported output type")
def _group_by(table, aggregates, keys, use_threads=True):
    """
    Aggregate a table with an "aggregate" node.

    Parameters
    ----------
    table : Table
        The table handed to the "table_source" node.
    aggregates
        The aggregations, forwarded to ``AggregateNodeOptions``.
    keys
        The grouping keys, forwarded to ``AggregateNodeOptions``.
    use_threads : bool, default True
        Whether to use multithreading or not.

    Returns
    -------
    Table
    """
    source_node = Declaration("table_source", TableSourceNodeOptions(table))
    aggregate_node = Declaration(
        "aggregate", AggregateNodeOptions(aggregates, keys=keys)
    )
    plan = Declaration.from_sequence([source_node, aggregate_node])
    return plan.to_table(use_threads=use_threads)