| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
import os
import time

import pyarrow as pa
import pyarrow.compute as pc
from datafusion import SessionContext, col, lit, udf, functions as F
| |
| path = os.path.dirname(os.path.abspath(__file__)) |
filepath = os.path.join(path, "tpch/data/lineitem.parquet")
| |
# This example demonstrates alternative approaches to answering the question
# "return all of the rows that match one of these specific combinations of
# values". The combinations we care about are provided as a Python list of
# tuples. There is no built-in function that supports this operation directly,
# but we can spell it out as a single expression or use a user-defined
# function.
| |
| ctx = SessionContext() |
| |
# These part keys and supplier keys are chosen because the data contains
# cases where two suppliers each have two of the part keys, but we are only
# interested in the specific combinations below.
| |
| values_of_interest = [ |
| (1530, 4031, "N"), |
| (6530, 1531, "N"), |
| (5618, 619, "N"), |
| (8118, 8119, "N"), |
| ] |
| |
| partkeys = [lit(r[0]) for r in values_of_interest] |
| suppkeys = [lit(r[1]) for r in values_of_interest] |
| returnflags = [lit(r[2]) for r in values_of_interest] |
| |
| df_lineitem = ctx.read_parquet(filepath).select( |
| "l_partkey", "l_suppkey", "l_returnflag" |
| ) |
| |
| start_time = time.time() |
| |
| df_simple_filter = df_lineitem.filter( |
| F.in_list(col("l_partkey"), partkeys), |
| F.in_list(col("l_suppkey"), suppkeys), |
| F.in_list(col("l_returnflag"), returnflags), |
| ) |
| |
| num_rows = df_simple_filter.count() |
print(
    f"Simple filtering found {num_rows} rows and took {time.time() - start_time:.3f} s"
)
| print("This is the incorrect number of rows!") |
| start_time = time.time() |
| |
# Explicitly check for the combinations of interest.
# This works but does not scale as the list of combinations grows; a generic
# construction of the same expression is sketched after the timing below.
| |
| filter_expr = ( |
| ( |
| (col("l_partkey") == values_of_interest[0][0]) |
| & (col("l_suppkey") == values_of_interest[0][1]) |
| & (col("l_returnflag") == values_of_interest[0][2]) |
| ) |
| | ( |
| (col("l_partkey") == values_of_interest[1][0]) |
| & (col("l_suppkey") == values_of_interest[1][1]) |
| & (col("l_returnflag") == values_of_interest[1][2]) |
| ) |
| | ( |
| (col("l_partkey") == values_of_interest[2][0]) |
| & (col("l_suppkey") == values_of_interest[2][1]) |
| & (col("l_returnflag") == values_of_interest[2][2]) |
| ) |
| | ( |
| (col("l_partkey") == values_of_interest[3][0]) |
| & (col("l_suppkey") == values_of_interest[3][1]) |
| & (col("l_returnflag") == values_of_interest[3][2]) |
| ) |
| ) |
| |
| df_explicit_filter = df_lineitem.filter(filter_expr) |
| |
| num_rows = df_explicit_filter.count() |
print(
    f"Explicit filtering found {num_rows} rows and took {time.time() - start_time:.3f} s"
)
| start_time = time.time() |
| |
# Instead, try a Python UDF that checks each row against the list of tuples.
| |
| |
| def is_of_interest_impl( |
| partkey_arr: pa.Array, |
| suppkey_arr: pa.Array, |
| returnflag_arr: pa.Array, |
| ) -> pa.Array: |
    result = []
    for partkey, suppkey, returnflag in zip(partkey_arr, suppkey_arr, returnflag_arr):
        value = (partkey.as_py(), suppkey.as_py(), returnflag.as_py())
        result.append(value in values_of_interest)
| |
| return pa.array(result) |
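
# The loop above crosses the Arrow-to-Python boundary once per value via
# .as_py(), which is what makes this approach slow on large batches. For a
# long list of tuples, converting values_of_interest to a set would also make
# the membership test O(1) instead of O(n).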
| |
| |
| is_of_interest = udf( |
| is_of_interest_impl, |
| [pa.int64(), pa.int64(), pa.utf8()], |
| pa.bool_(), |
| "stable", |
| ) |
| |
| df_udf_filter = df_lineitem.filter( |
| is_of_interest(col("l_partkey"), col("l_suppkey"), col("l_returnflag")) |
| ) |
| |
| num_rows = df_udf_filter.count() |
| print(f"UDF filtering has number {num_rows} rows and took {time.time() - start_time} s") |
| start_time = time.time() |
| |
# Now use a user-defined function, but lean on the built-in pyarrow compute
# functions so we never convert individual values to Python objects.
| |
# For other pyarrow compute functions, see
# https://arrow.apache.org/docs/python/api/compute.html
#
# It is important that the returned array has the same number of rows as the
# input arrays, so we cannot use functions such as Array.filter() that drop
# rows.
| |
| |
| def udf_using_pyarrow_compute_impl( |
| partkey_arr: pa.Array, |
| suppkey_arr: pa.Array, |
| returnflag_arr: pa.Array, |
| ) -> pa.Array: |
    results = None
    for partkey, suppkey, returnflag in values_of_interest:
        # pc.equal produces a boolean mask with one entry per input row
        partkey_matches = pc.equal(partkey_arr, partkey)
        suppkey_matches = pc.equal(suppkey_arr, suppkey)
        returnflag_matches = pc.equal(returnflag_arr, returnflag)

        combined = pc.and_(partkey_matches, suppkey_matches)
        combined = pc.and_(combined, returnflag_matches)

        if results is None:
            results = combined
        else:
            results = pc.or_(results, combined)

    return results
| |
| |
| udf_using_pyarrow_compute = udf( |
| udf_using_pyarrow_compute_impl, |
| [pa.int64(), pa.int64(), pa.utf8()], |
| pa.bool_(), |
| "stable", |
| ) |
| |
| df_udf_pyarrow_compute = df_lineitem.filter( |
| udf_using_pyarrow_compute(col("l_partkey"), col("l_suppkey"), col("l_returnflag")) |
| ) |
| |
| num_rows = df_udf_pyarrow_compute.count() |
print(
    f"UDF filtering using pyarrow compute found {num_rows} rows and took {time.time() - start_time:.3f} s"
)
| start_time = time.time() |
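
# One more alternative, sketched here rather than taken from the original
# example: encode each row as a single composite string key with pyarrow and
# test membership with pc.is_in. This assumes the separator "|" never appears
# in the data, which holds for these integer keys and one-character flags.


def udf_using_is_in_impl(
    partkey_arr: pa.Array,
    suppkey_arr: pa.Array,
    returnflag_arr: pa.Array,
) -> pa.Array:
    # Build a "partkey|suppkey|returnflag" key for every row in one pass
    keys = pc.binary_join_element_wise(
        pc.cast(partkey_arr, pa.string()),
        pc.cast(suppkey_arr, pa.string()),
        returnflag_arr,
        "|",
    )
    wanted = pa.array([f"{p}|{s}|{r}" for (p, s, r) in values_of_interest])
    return pc.is_in(keys, value_set=wanted)


udf_using_is_in = udf(
    udf_using_is_in_impl,
    [pa.int64(), pa.int64(), pa.utf8()],
    pa.bool_(),
    "stable",
)

df_udf_is_in = df_lineitem.filter(
    udf_using_is_in(col("l_partkey"), col("l_suppkey"), col("l_returnflag"))
)

num_rows = df_udf_is_in.count()
print(
    f"UDF filtering using pc.is_in found {num_rows} rows "
    f"and took {time.time() - start_time:.3f} s"
)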