#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
A simple example demonstrating Python UDTFs in Spark
Run with:
./bin/spark-submit examples/src/main/python/sql/udtf.py
"""
# NOTE that this file is imported in the tutorials in PySpark documentation.
# The codes are referred via line numbers. See also `literalinclude` directive in Sphinx.
from pyspark.sql import SparkSession
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version

# Python UDTFs use Arrow by default.
require_minimum_pandas_version()
require_minimum_pyarrow_version()

def python_udtf_simple_example(spark: SparkSession) -> None:

    # Define the UDTF class and implement the required `eval` method.
    class SquareNumbers:
        def eval(self, start: int, end: int):
            for num in range(start, end + 1):
                yield (num, num * num)

    from pyspark.sql.functions import lit, udtf

    # Create a UDTF using the class definition and the `udtf` function.
    square_num = udtf(SquareNumbers, returnType="num: int, squared: int")

    # Invoke the UDTF in PySpark.
    square_num(lit(1), lit(3)).show()
    # +---+-------+
    # |num|squared|
    # +---+-------+
    # |  1|      1|
    # |  2|      4|
    # |  3|      9|
    # +---+-------+

def python_udtf_decorator_example(spark: SparkSession) -> None:

    from pyspark.sql.functions import lit, udtf

    # Define a UDTF using the `udtf` decorator directly on the class.
    @udtf(returnType="num: int, squared: int")
    class SquareNumbers:
        def eval(self, start: int, end: int):
            for num in range(start, end + 1):
                yield (num, num * num)

    # Invoke the UDTF in PySpark using the SquareNumbers class directly.
    SquareNumbers(lit(1), lit(3)).show()
    # +---+-------+
    # |num|squared|
    # +---+-------+
    # |  1|      1|
    # |  2|      4|
    # |  3|      9|
    # +---+-------+

def python_udtf_registration(spark: SparkSession) -> None:

    from pyspark.sql.functions import udtf

    @udtf(returnType="word: string")
    class WordSplitter:
        def eval(self, text: str):
            for word in text.split(" "):
                yield (word.strip(),)

    # Register the UDTF for use in Spark SQL.
    spark.udtf.register("split_words", WordSplitter)

    # Example: Using the UDTF in SQL.
    spark.sql("SELECT * FROM split_words('hello world')").show()
    # +-----+
    # | word|
    # +-----+
    # |hello|
    # |world|
    # +-----+

    # Example: Using the UDTF with a lateral join in SQL.
    # The lateral join allows us to reference the columns and aliases
    # in the previous FROM clause items as inputs to the UDTF.
    spark.sql(
        "SELECT * FROM VALUES ('Hello World'), ('Apache Spark') t(text), "
        "LATERAL split_words(text)"
    ).show()
    # +------------+------+
    # |        text|  word|
    # +------------+------+
    # | Hello World| Hello|
    # | Hello World| World|
    # |Apache Spark|Apache|
    # |Apache Spark| Spark|
    # +------------+------+

def python_udtf_arrow_example(spark: SparkSession) -> None:

    from pyspark.sql.functions import udtf

    @udtf(returnType="c1: int, c2: int", useArrow=True)
    class PlusOne:
        def eval(self, x: int):
            yield x, x + 1
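
    # The class above only defines the Arrow-backed UDTF; a minimal
    # illustrative invocation (not part of the original example) shows
    # that it is called like any other UDTF:
    from pyspark.sql.functions import lit

    PlusOne(lit(1)).show()
    # +---+---+
    # | c1| c2|
    # +---+---+
    # |  1|  2|
    # +---+---+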

def python_udtf_date_expander_example(spark: SparkSession) -> None:

    from datetime import datetime, timedelta

    from pyspark.sql.functions import lit, udtf

    @udtf(returnType="date: string")
    class DateExpander:
        def eval(self, start_date: str, end_date: str):
            current = datetime.strptime(start_date, '%Y-%m-%d')
            end = datetime.strptime(end_date, '%Y-%m-%d')
            while current <= end:
                yield (current.strftime('%Y-%m-%d'),)
                current += timedelta(days=1)

    DateExpander(lit("2023-02-25"), lit("2023-03-01")).show()
    # +----------+
    # |      date|
    # +----------+
    # |2023-02-25|
    # |2023-02-26|
    # |2023-02-27|
    # |2023-02-28|
    # |2023-03-01|
    # +----------+

def python_udtf_terminate_example(spark: SparkSession) -> None:

    from pyspark.sql.functions import udtf

    @udtf(returnType="cnt: int")
    class CountUDTF:
        def __init__(self):
            # Initialize the counter to 0 when an instance of the class is created.
            self.count = 0

        def eval(self, x: int):
            # Increment the counter by 1 for each input value received.
            self.count += 1

        def terminate(self):
            # Yield the final count when the UDTF is done processing.
            yield self.count,

    spark.udtf.register("count_udtf", CountUDTF)

    # The fourth argument to range() is the number of partitions; one UDTF
    # instance runs per partition, so each instance counts only its own rows.
    spark.sql("SELECT * FROM range(0, 10, 1, 1), LATERAL count_udtf(id)").show()
    # +---+---+
    # | id|cnt|
    # +---+---+
    # |  9| 10|
    # +---+---+
    spark.sql("SELECT * FROM range(0, 10, 1, 2), LATERAL count_udtf(id)").show()
    # +---+---+
    # | id|cnt|
    # +---+---+
    # |  4|  5|
    # |  9|  5|
    # +---+---+

def python_udtf_table_argument(spark: SparkSession) -> None:

    from pyspark.sql.functions import udtf
    from pyspark.sql.types import Row

    @udtf(returnType="id: int")
    class FilterUDTF:
        def eval(self, row: Row):
            if row["id"] > 5:
                yield row["id"],

    spark.udtf.register("filter_udtf", FilterUDTF)

    spark.sql("SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)))").show()
    # +---+
    # | id|
    # +---+
    # |  6|
    # |  7|
    # |  8|
    # |  9|
    # +---+
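
    # A table argument can also reference a table or view by name; a brief
    # illustrative variant (the temp view name "numbers" is our own choice,
    # not part of the original example):
    spark.range(10).createOrReplaceTempView("numbers")
    spark.sql("SELECT * FROM filter_udtf(TABLE(numbers))").show()
    # Same four rows (6 through 9) as above.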

def python_udtf_table_argument_with_partitioning(spark: SparkSession) -> None:

    from pyspark.sql.functions import udtf
    from pyspark.sql.types import Row

    # Define and register a UDTF.
    @udtf(returnType="a: string, b: int")
    class FilterUDTF:
        def __init__(self):
            self.key = ""
            self.max = 0

        def eval(self, row: Row):
            self.key = row["a"]
            self.max = max(self.max, row["b"])

        def terminate(self):
            yield self.key, self.max

    spark.udtf.register("filter_udtf", FilterUDTF)

    # Create an input table with some example values.
    spark.sql("DROP TABLE IF EXISTS values_table")
    spark.sql("CREATE TABLE values_table (a STRING, b INT)")
    spark.sql("INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)")
    spark.table("values_table").show()
    # +---+---+
    # |  a|  b|
    # +---+---+
    # |abc|  2|
    # |abc|  4|
    # |def|  6|
    # |def|  8|
    # +---+---+

    # Query the UDTF with the input table as an argument, and a directive to partition the input
    # rows such that all rows with each unique value of the `a` column are processed by the same
    # instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
    spark.sql("""
        SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1
    """).show()
    # +---+---+
    # |  a|  b|
    # +---+---+
    # |abc|  4|
    # |def|  8|
    # +---+---+

    # Query the UDTF with the input table as an argument, and a directive to partition the input
    # rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
    # processed by the same instance of the UDTF class. Within each partition, the rows are ordered
    # by the `b` column.
    spark.sql("""
        SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1
    """).show()
    # +---+---+
    # |  a|  b|
    # +---+---+
    # |def|  8|
    # +---+---+

    # Query the UDTF with the input table as an argument, and a directive to consider all the input
    # rows as one single partition such that exactly one instance of the UDTF class consumes all of
    # the input rows. Within that single partition, the rows are ordered by the `b` column.
    spark.sql("""
        SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1
    """).show()
    # +---+---+
    # |  a|  b|
    # +---+---+
    # |def|  8|
    # +---+---+

    # Clean up.
    spark.sql("DROP TABLE values_table")

if __name__ == "__main__":
    spark = SparkSession \
        .builder \
        .appName("Python UDTF example") \
        .getOrCreate()

    print("Running Python UDTF simple example")
    python_udtf_simple_example(spark)

    print("Running Python UDTF decorator example")
    python_udtf_decorator_example(spark)

    print("Running Python UDTF registration example")
    python_udtf_registration(spark)

    print("Running Python UDTF arrow example")
    python_udtf_arrow_example(spark)

    print("Running Python UDTF date expander example")
    python_udtf_date_expander_example(spark)

    print("Running Python UDTF terminate example")
    python_udtf_terminate_example(spark)

    print("Running Python UDTF table argument example")
    python_udtf_table_argument(spark)

    print("Running Python UDTF table argument with partitioning example")
    python_udtf_table_argument_with_partitioning(spark)

    spark.stop()