| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "4ee2125b-f889-47e6-9c3d-8bd63a253683", |
| "metadata": {}, |
| "source": [ |
| "# Testing PySpark\n", |
| "\n", |
| "This guide is a reference for writing robust tests for PySpark code.\n", |
| "\n", |
| "To view the docs for PySpark test utils, see here. To see the code for PySpark built-in test utils, check out the Spark repository here. To see the JIRA board tickets for the PySpark test framework, see here." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "0e8ee4b6-9544-45e1-8a91-e71ed8ef8b9d", |
| "metadata": {}, |
| "source": [ |
| "## Build a PySpark Application\n", |
| "Here is an example for how to start a PySpark application. Feel free to skip to the next section, “Testing your PySpark Application,” if you already have an application you’re ready to test.\n", |
| "\n", |
| "First, start your Spark Session." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "9af4a35b-17e8-4e45-816b-34c14c5902f7", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from pyspark.sql import SparkSession \n", |
| "from pyspark.sql.functions import col \n", |
| "\n", |
| "# Create a SparkSession \n", |
| "spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() " |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "4a4c6efe-91f5-4e18-b4b2-b0401c2368e4", |
| "metadata": {}, |
| "source": [ |
| "Next, create a DataFrame." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "3b483dd8-3a76-41c6-9206-301d7ef314d6", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}] \n", |
| "\n", |
| "df = spark.createDataFrame(sample_data)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "e0f44333-0e08-470b-9fa2-38f59e3dbd63", |
| "metadata": {}, |
| "source": [ |
| "Now, let’s define and apply a transformation function to our DataFrame." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "a6c0b766-af5f-4e1d-acf8-887d7cf0b0b2", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "+---+--------+\n", |
| "|age| name|\n", |
| "+---+--------+\n", |
| "| 30| John D.|\n", |
| "| 25|Alice G.|\n", |
| "| 35| Bob T.|\n", |
| "| 28| Eve A.|\n", |
| "+---+--------+\n", |
| "\n" |
| ] |
| } |
| ], |
| "source": [ |
| "from pyspark.sql.functions import col, regexp_replace\n", |
| "\n", |
| "# Remove additional spaces in name\n", |
| "def remove_extra_spaces(df, column_name):\n", |
| " # Remove extra spaces from the specified column\n", |
| " df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n", |
| " \n", |
| " return df_transformed\n", |
| "\n", |
| "transformed_df = remove_extra_spaces(df, \"name\")\n", |
| "\n", |
| "transformed_df.show()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "530beaa6-aabf-43a1-ad2b-361f267e9608", |
| "metadata": {}, |
| "source": [ |
| "## Testing your PySpark Application\n", |
| "Now let’s test our PySpark transformation function. \n", |
| "\n", |
| "One option is to simply eyeball the resulting DataFrame. However, this can be impractical for large DataFrame or input sizes.\n", |
| "\n", |
| "A better way is to write tests. Here are some examples of how we can test our code. The examples below apply for Spark 3.5 and above versions.\n", |
| "\n", |
| "Note that these examples are not exhaustive, as there are many other test framework alternatives which you can use instead of `unittest` or `pytest`. The built-in PySpark testing util functions are standalone, meaning they can be compatible with any test framework or CI test pipeline.\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d84a9fc1-9768-4af4-bfbf-e832f23334dc", |
| "metadata": {}, |
| "source": [ |
| "### Option 1: Using Only PySpark Built-in Test Utility Functions\n", |
| "\n", |
| "For simple ad-hoc validation cases, PySpark testing utils like `assertDataFrameEqual` and `assertSchemaEqual` can be used in a standalone context.\n", |
| "You could easily test PySpark code in a notebook session. For example, say you want to assert equality between two DataFrames:\n", |
| "\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "8e533732-ee40-4cd0-9669-8eb92973908a", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import pyspark.testing\n", |
| "from pyspark.testing.utils import assertDataFrameEqual\n", |
| "\n", |
| "# Example 1\n", |
| "df1 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n", |
| "df2 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n", |
| "assertDataFrameEqual(df1, df2) # pass, DataFrames are identical" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "id": "2d77a6be-1e50-4c1a-8a44-85cf7dcec3f3", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Example 2\n", |
| "df1 = spark.createDataFrame(data=[(\"1\", 0.1), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n", |
| "df2 = spark.createDataFrame(data=[(\"1\", 0.109), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n", |
| "assertDataFrameEqual(df1, df2, rtol=1e-1) # pass, DataFrames are approx equal by rtol" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "76ade5f2-4a1f-4601-9d2a-80da9da950ff", |
| "metadata": {}, |
| "source": [ |
| "You can also simply compare two DataFrame schemas:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "id": "74393af5-40fb-4d04-87cb-265971ffe6d0", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from pyspark.testing.utils import assertSchemaEqual\n", |
| "from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType\n", |
| "\n", |
| "s1 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n", |
| "s2 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n", |
| "\n", |
| "assertSchemaEqual(s1, s2) # pass, schemas are identical" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "c67be105-f6b1-4083-ad11-9e819331eae8", |
| "metadata": {}, |
| "source": [ |
| "### Option 2: Using [Unit Test](https://docs.python.org/3/library/unittest.html)\n", |
| "For more complex testing scenarios, you may want to use a testing framework.\n", |
| "\n", |
| "One of the most popular testing framework options is unit tests. Let’s walk through how you can use the built-in Python `unittest` library to write PySpark tests. For more information about the `unittest` library, see here: https://docs.python.org/3/library/unittest.html. \n", |
| "\n", |
| "First, you will need a Spark session. You can use the `@classmethod` decorator from the `unittest` package to take care of setting up and tearing down a Spark session." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 15, |
| "id": "54093761-0b49-4aee-baec-2d29bcf13f9f", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import unittest\n", |
| "\n", |
| "class PySparkTestCase(unittest.TestCase):\n", |
| " @classmethod\n", |
| " def setUpClass(cls):\n", |
| " cls.spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() \n", |
| "\n", |
| " \n", |
| " @classmethod\n", |
| " def tearDownClass(cls):\n", |
| " cls.spark.stop()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3de27500-8526-412e-bf09-6927a760c5d7", |
| "metadata": {}, |
| "source": [ |
| "Now let’s write a `unittest` class." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 17, |
| "id": "34feb5e1-944f-4f6b-9c5f-3b0bf68c7d05", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from pyspark.testing.utils import assertDataFrameEqual\n", |
| "\n", |
| "class TestTranformation(PySparkTestCase):\n", |
| " def test_single_space(self):\n", |
| " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}] \n", |
| " \n", |
| " # Create a Spark DataFrame\n", |
| " original_df = spark.createDataFrame(sample_data)\n", |
| " \n", |
| " # Apply the transformation function from before\n", |
| " transformed_df = remove_extra_spaces(original_df, \"name\")\n", |
| " \n", |
| " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}]\n", |
| " \n", |
| " expected_df = spark.createDataFrame(expected_data)\n", |
| " \n", |
| " assertDataFrameEqual(transformed_df, expected_df)\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "319a690f-71bd-4886-bd3a-424e866525c2", |
| "metadata": {}, |
| "source": [ |
| "When run, `unittest` will pick up all functions with a name beginning with “test.”" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "7d79e53d-cc1e-4fdf-a069-478337bed83d", |
| "metadata": {}, |
| "source": [ |
| "### Option 3: Using [Pytest](https://docs.pytest.org/en/7.1.x/contents.html)\n", |
| "\n", |
| "We can also write our tests with `pytest`, which is one of the most popular Python testing frameworks. For more information about `pytest`, see the docs here: https://docs.pytest.org/en/7.1.x/contents.html.\n", |
| "\n", |
| "Using a `pytest` fixture allows us to share a spark session across tests, tearing it down when the tests are complete." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 20, |
| "id": "60a4f304-1911-4b4d-8ed9-00ecc8b0890b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import pytest\n", |
| "\n", |
| "@pytest.fixture\n", |
| "def spark_fixture():\n", |
| " spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate()\n", |
| " yield spark" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "fcb4e26a-9bfc-48a5-8aca-538697d66642", |
| "metadata": {}, |
| "source": [ |
| "We can then define our tests like this:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 22, |
| "id": "fa5db3a1-7305-44b7-ab84-f5ed55fd2ba9", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import pytest\n", |
| "from pyspark.testing.utils import assertDataFrameEqual\n", |
| "\n", |
| "def test_single_space(spark_fixture):\n", |
| " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}] \n", |
| " \n", |
| " # Create a Spark DataFrame\n", |
| " original_df = spark.createDataFrame(sample_data)\n", |
| " \n", |
| " # Apply the transformation function from before\n", |
| " transformed_df = remove_extra_spaces(original_df, \"name\")\n", |
| " \n", |
| " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}]\n", |
| " \n", |
| " expected_df = spark.createDataFrame(expected_data)\n", |
| "\n", |
| " assertDataFrameEqual(transformed_df, expected_df)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "0fc3f394-3260-4e42-82cf-1a7edc859151", |
| "metadata": {}, |
| "source": [ |
| "When you run your test file with the `pytest` command, it will pick up all functions that have their name beginning with “test.”" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d8f50eee-5d0b-4719-b505-1b3ff05c16e8", |
| "metadata": {}, |
| "source": [ |
| "## Putting It All Together!\n", |
| "\n", |
| "Let’s see all the steps together, in a Unit Test example." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 25, |
| "id": "a2ea9dec-0ac0-4c23-8770-d6cc226d2e97", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# pkg/etl.py\n", |
| "import unittest\n", |
| "\n", |
| "from pyspark.sql import SparkSession \n", |
| "from pyspark.sql.functions import col\n", |
| "from pyspark.sql.functions import regexp_replace\n", |
| "from pyspark.testing.utils import assertDataFrameEqual\n", |
| "\n", |
| "# Create a SparkSession \n", |
| "spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n", |
| "\n", |
| "sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}] \n", |
| "\n", |
| "df = spark.createDataFrame(sample_data)\n", |
| "\n", |
| "# Define DataFrame transformation function\n", |
| "def remove_extra_spaces(df, column_name):\n", |
| " # Remove extra spaces from the specified column using regexp_replace\n", |
| " df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n", |
| "\n", |
| " return df_transformed" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 26, |
| "id": "248aede2-feb9-4828-bd9c-8e25e6b194ab", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# pkg/test_etl.py\n", |
| "import unittest\n", |
| "\n", |
| "from pyspark.sql import SparkSession \n", |
| "\n", |
| "# Define unit test base class\n", |
| "class PySparkTestCase(unittest.TestCase):\n", |
| " @classmethod\n", |
| " def setUpClass(cls):\n", |
| " cls.spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n", |
| "\n", |
| " @classmethod\n", |
| " def tearDownClass(cls):\n", |
| " cls.spark.stop()\n", |
| " \n", |
| "# Define unit test\n", |
| "class TestTranformation(PySparkTestCase):\n", |
| " def test_single_space(self):\n", |
| " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}] \n", |
| " \n", |
| " # Create a Spark DataFrame\n", |
| " original_df = spark.createDataFrame(sample_data)\n", |
| " \n", |
| " # Apply the transformation function from before\n", |
| " transformed_df = remove_extra_spaces(original_df, \"name\")\n", |
| " \n", |
| " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", |
| " {\"name\": \"Alice G.\", \"age\": 25}, \n", |
| " {\"name\": \"Bob T.\", \"age\": 35}, \n", |
| " {\"name\": \"Eve A.\", \"age\": 28}]\n", |
| " \n", |
| " expected_df = spark.createDataFrame(expected_data)\n", |
| " \n", |
| " assertDataFrameEqual(transformed_df, expected_df)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 27, |
| "id": "a77df5b2-f32e-4d8c-a64b-0078dfa21217", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Ran 1 test in 1.734s\n", |
| "\n", |
| "OK\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/plain": [ |
| "<unittest.main.TestProgram at 0x174539db0>" |
| ] |
| }, |
| "execution_count": 27, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "unittest.main(argv=[''], verbosity=0, exit=False)" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "jupyter-oss-env", |
| "language": "python", |
| "name": "jupyter-oss-env" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.10.9" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |