| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import time |
| import unittest |
| |
| from pyspark.sql import Row, Observation, functions as F |
| from pyspark.errors import ( |
| PySparkAssertionError, |
| PySparkTypeError, |
| PySparkValueError, |
| ) |
| from pyspark.testing.sqlutils import ReusedSQLTestCase |
| from pyspark.testing.utils import assertDataFrameEqual |
| |
| |
| class DataFrameObservationTestsMixin: |
| def test_observe(self): |
| # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method |
| df = self.spark.createDataFrame( |
| [ |
| (1, 1.0, "one"), |
| (2, 2.0, "two"), |
| (3, 3.0, "three"), |
| ], |
| ["id", "val", "label"], |
| ) |
| |
| unnamed_observation = Observation() |
| named_observation = Observation("metric") |
| |
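| # metrics only become available once an action has run on the observed |
| # DataFrame; reading them beforehand must fail |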
| with self.assertRaises(PySparkAssertionError) as pe: |
| unnamed_observation.get  # `get` is a property; accessing it is enough to raise |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NO_OBSERVE_BEFORE_GET", |
| messageParameters={}, |
| ) |
| |
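| # attach both observations; their metrics are aggregated while the rows |
| # flow through the respective observe node |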
| observed = ( |
| df.orderBy("id") |
| .observe( |
| named_observation, |
| F.count(F.lit(1)).alias("cnt"), |
| F.sum(F.col("id")).alias("sum"), |
| F.mean(F.col("val")).alias("mean"), |
| ) |
| .observe(unnamed_observation, F.count(F.lit(1)).alias("rows")) |
| ) |
| |
| # test that observe is transparent: the query result is unchanged |
| actual = observed.collect() |
| self.assertEqual( |
| [ |
| {"id": 1, "val": 1.0, "label": "one"}, |
| {"id": 2, "val": 2.0, "label": "two"}, |
| {"id": 3, "val": 3.0, "label": "three"}, |
| ], |
| [row.asDict() for row in actual], |
| ) |
| |
| # test that the metrics are available once the action has completed |
| self.assertEqual(named_observation.get, dict(cnt=3, sum=6, mean=2.0)) |
| self.assertEqual(unnamed_observation.get, dict(rows=3)) |
| |
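| # an Observation instance can be attached to at most one DataFrame |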
| with self.assertRaises(PySparkAssertionError) as pe: |
| df.observe(named_observation, F.count(F.lit(1)).alias("count")) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="REUSE_OBSERVATION", |
| messageParameters={}, |
| ) |
| |
| # Observation requires the name, if given, to be a non-empty string |
| with self.assertRaisesRegex(TypeError, "`name` should be a str, got int"): |
| Observation(123) |
| with self.assertRaisesRegex(ValueError, "`name` must be a non-empty string, got ''."): |
| Observation("") |
| |
| # DataFrame.observe requires at least one expression |
| with self.assertRaises(PySparkValueError) as pe: |
| df.observe(Observation()) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="CANNOT_BE_EMPTY", |
| messageParameters={"item": "exprs"}, |
| ) |
| |
| # DataFrame.observe requires every expression to be a Column, not None or str |
| for args in [(None,), ("id",), (F.lit(1), None), (F.lit(1), "id")]: |
| with self.subTest(args=args): |
| with self.assertRaises(PySparkTypeError) as pe: |
| df.observe(Observation(), *args) |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="NOT_LIST_OF_COLUMN", |
| messageParameters={"arg_name": "exprs"}, |
| ) |
| |
| def test_observe_str(self): |
| # SPARK-38760: tests the DataFrame.observe(str, *Column) method |
| from pyspark.sql.streaming import StreamingQueryListener |
| |
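| # the listener captures the observed metrics reported with each progress event |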
| observed_metrics = None |
| |
| class TestListener(StreamingQueryListener): |
| def onQueryStarted(self, event): |
| pass |
| |
| def onQueryProgress(self, event): |
| nonlocal observed_metrics |
| observed_metrics = event.progress.observedMetrics |
| |
| def onQueryIdle(self, event): |
| pass |
| |
| def onQueryTerminated(self, event): |
| pass |
| |
| test_listener = TestListener() |
| self.spark.streams.addListener(test_listener) |
| |
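| # the rate source continuously generates rows, so the observed metrics grow |
| # while the query runs |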
| df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() |
| df = df.observe( |
| "metric", F.count(F.lit(1)).alias("cnt"), F.sum(F.col("value")).alias("sum") |
| ) |
| q = df.writeStream.format("noop").queryName("test").start() |
| self.assertTrue(q.isActive) |
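| # give the query enough time to report at least one progress event |
| # carrying the observed metrics |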
| time.sleep(10) |
| q.stop() |
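| # detach the listener so it does not leak into tests that run afterwards |
| self.spark.streams.removeListener(test_listener) |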
| |
| self.assertIsInstance(observed_metrics, dict) |
| self.assertIn("metric", observed_metrics) |
| row = observed_metrics["metric"] |
| self.assertIsInstance(row, Row) |
| self.assertTrue(hasattr(row, "cnt")) |
| self.assertTrue(hasattr(row, "sum")) |
| self.assertGreaterEqual(row.cnt, 0) |
| self.assertGreaterEqual(row.sum, 0) |
| |
| def test_observe_with_same_name_on_different_dataframe(self): |
| # SPARK-45656: named observations with the same name on different datasets |
| observation1 = Observation("named") |
| df1 = self.spark.range(50) |
| observed_df1 = df1.observe(observation1, F.count(F.lit(1)).alias("cnt")) |
| |
| observation2 = Observation("named") |
| df2 = self.spark.range(100) |
| observed_df2 = df2.observe(observation2, F.count(F.lit(1)).alias("cnt")) |
| |
| observed_df1.collect() |
| observed_df2.collect() |
| |
| self.assertEqual(observation1.get, dict(cnt=50)) |
| self.assertEqual(observation2.get, dict(cnt=100)) |
| |
| def test_observe_on_commands(self): |
| df = self.spark.range(50) |
| |
| test_table = "test_table" |
| |
| # observations must be served by queries (collect, show) as well as by |
| # DataFrameWriter / DataFrameWriterV2 commands (save, create) |
| with self.table(test_table): |
| for command, action in [ |
| ("collect", lambda df: df.collect()), |
| ("show", lambda df: df.show(50)), |
| ("save", lambda df: df.write.format("noop").mode("overwrite").save()), |
| ("create", lambda df: df.writeTo(test_table).using("parquet").create()), |
| ]: |
| with self.subTest(command=command): |
| observation = Observation() |
| observed_df = df.observe(observation, F.count(F.lit(1)).alias("cnt")) |
| action(observed_df) |
| self.assertEqual(observation.get, dict(cnt=50)) |
| |
| def test_observe_with_struct_type(self): |
| observation = Observation("struct") |
| |
| df = self.spark.range(10).observe( |
| observation, |
| F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"), |
| ) |
| |
| assertDataFrameEqual(df, [Row(id=i) for i in range(10)]) |
| |
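| # a struct metric is returned as a Row |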
| self.assertEqual(observation.get, {"struct": Row(rows=10, maxid=9)}) |
| |
| def test_observe_with_array_type(self): |
| observation = Observation("array") |
| |
| df = self.spark.range(10).observe( |
| observation, |
| F.array(F.count(F.lit(1))).alias("array"), |
| ) |
| |
| assertDataFrameEqual(df, [Row(id=i) for i in range(10)]) |
| |
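| # an array metric is returned as a Python list |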
| self.assertEqual(observation.get, {"array": [10]}) |
| |
| def test_observe_with_map_type(self): |
| observation = Observation("map") |
| |
| df = self.spark.range(10).observe( |
| observation, |
| F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"), |
| ) |
| |
| assertDataFrameEqual(df, [Row(id=i) for i in range(10)]) |
| |
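| # a map metric is returned as a Python dict |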
| self.assertEqual(observation.get, {"map": {"count": 10}}) |
| |
| |
| class DataFrameObservationTests( |
| DataFrameObservationTestsMixin, |
| ReusedSQLTestCase, |
| ): |
| pass |
| |
| |
| if __name__ == "__main__": |
| from pyspark.sql.tests.test_observation import * # noqa: F401 |
| |
| try: |
| import xmlrunner # type: ignore |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |