| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import unittest |
| from typing import cast |
| |
| from pyspark.errors import PySparkRuntimeError, PySparkTypeError |
| from pyspark.sql import Column |
| from pyspark.testing.connectutils import ( |
| should_test_connect, |
| connect_requirement_message, |
| ) |
| |
| if should_test_connect: |
| from pyspark import pipelines as dp |
| from pyspark.pipelines.graph_element_registry import graph_element_registration_context |
| from pyspark.pipelines.flow import AutoCdcFlow |
| from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry |
| from pyspark.sql.connect.functions.builtin import col, expr |
| |
| |
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class AutoCdcFlowConstructionTest(unittest.TestCase):
    """Checks what ``dp.create_auto_cdc_flow`` registers and which arguments it rejects.

    Each test installs a ``LocalGraphElementRegistry`` through
    ``graph_element_registration_context`` and then inspects the ``AutoCdcFlow``
    recorded in ``registry.auto_cdc_flows``. The whole class is skipped when the
    Spark Connect test requirements are not met.
    """

    # Minimal call: only the required arguments are supplied.
    def test_create_auto_cdc_flow(self):
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("target")
            dp.create_auto_cdc_flow(
                target="target",
                source="source",
                keys=[col("key")],
                sequence_by=expr("seq"),
            )

        self.assertEqual(len(registry.outputs), 1)
        self.assertEqual(len(registry.auto_cdc_flows), 1)

        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
        self.assertEqual(flow.target, "target")
        self.assertEqual(flow.source, "source")

        # When name is not specified, it inherits the target's name at construction time.
        self.assertEqual(flow.name, "target")
        self.assertIsNone(flow.stored_as_scd_type)
        self.assertIsNone(flow.apply_as_deletes)

        # Use a unittest assertion rather than a bare ``assert`` so the check is not
        # stripped when the interpreter runs with -O.
        filename = flow.source_code_location.filename
        self.assertTrue(
            filename.endswith("test_auto_cdc_flow.py"),
            f"unexpected source code location: {filename}",
        )

    # Every optional argument exercised by this suite is supplied at once.
    def test_create_auto_cdc_flow_with_all_args(self):
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("tgt")
            dp.create_auto_cdc_flow(
                target="tgt",
                source="src",
                keys=[col("id")],
                sequence_by=expr("ts"),
                apply_as_deletes=expr("op = 'DELETE'"),
                column_list=[col("id"), col("val")],
                stored_as_scd_type=1,
                name="my_flow",
            )

        self.assertEqual(len(registry.auto_cdc_flows), 1)

        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
        self.assertEqual(flow.name, "my_flow")
        self.assertEqual(flow.target, "tgt")
        self.assertEqual(flow.source, "src")
        self.assertEqual(flow.stored_as_scd_type, 1)

        # Columns are checked by presence/count only: ``==`` on a Column builds a new
        # Column expression instead of returning a bool, so assertEqual cannot be used.
        self.assertEqual(len(flow.keys), 1)
        self.assertIsNotNone(flow.sequence_by)
        self.assertIsNotNone(flow.apply_as_deletes)
        self.assertIsNotNone(flow.column_list)
        assert flow.column_list is not None  # narrows the Optional for type checkers
        self.assertEqual(len(flow.column_list), 2)

    def test_create_auto_cdc_flow_with_string_args(self):
        # Verify that string forms of column / expression arguments are normalized to
        # PySpark Columns, equivalent to passing col(...) / expr(...) directly.
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("tgt")
            dp.create_auto_cdc_flow(
                target="tgt",
                source="src",
                keys=["id"],
                sequence_by="ts",
                apply_as_deletes="op = 'DELETE'",
                column_list=["id", "val"],
            )

        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])

        # Pin the element counts first; otherwise the isinstance loops below would
        # pass vacuously if normalization dropped entries.
        self.assertEqual(len(flow.keys), 1)
        for k in flow.keys:
            self.assertIsInstance(k, Column)
        self.assertIsInstance(flow.sequence_by, Column)
        self.assertIsInstance(flow.apply_as_deletes, Column)

        self.assertIsNotNone(flow.column_list)
        assert flow.column_list is not None  # narrows the Optional for type checkers
        self.assertEqual(len(flow.column_list), 2)
        for c in flow.column_list:
            self.assertIsInstance(c, Column)

    # The string "1" is accepted for stored_as_scd_type and kept as given.
    def test_create_auto_cdc_flow_stored_as_scd_type_string(self):
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("t")
            dp.create_auto_cdc_flow(
                target="t",
                source="s",
                keys=[col("k")],
                sequence_by=expr("seq"),
                stored_as_scd_type="1",
            )

        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
        self.assertEqual(flow.stored_as_scd_type, "1")

    # stored_as_scd_type=2 must be rejected with NOT_EXPECTED_TYPE.
    def test_create_auto_cdc_flow_invalid_scd_type(self):
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("t")
            with self.assertRaises(PySparkTypeError) as ctx:
                dp.create_auto_cdc_flow(
                    target="t",
                    source="s",
                    keys=[col("k")],
                    sequence_by=expr("seq"),
                    stored_as_scd_type=2,  # type: ignore[arg-type]
                )
        self.assertEqual(ctx.exception.getCondition(), "NOT_EXPECTED_TYPE")

    # except_column_list given as strings is normalized to Columns; column_list stays None.
    def test_create_auto_cdc_flow_with_except_column_list(self):
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("tgt")
            dp.create_auto_cdc_flow(
                target="tgt",
                source="src",
                keys=[col("id")],
                sequence_by=expr("ts"),
                except_column_list=["op", "ts"],
            )

        flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0])
        self.assertIsNone(flow.column_list)

        self.assertIsNotNone(flow.except_column_list)
        assert flow.except_column_list is not None  # narrows the Optional for type checkers
        self.assertEqual(len(flow.except_column_list), 2)
        for c in flow.except_column_list:
            self.assertIsInstance(c, Column)

    # A non-string target must be rejected with NOT_EXPECTED_TYPE.
    def test_create_auto_cdc_flow_rejects_non_str_target(self):
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("tgt")
            with self.assertRaises(PySparkTypeError) as ctx:
                dp.create_auto_cdc_flow(
                    target=123,  # type: ignore[arg-type]
                    source="src",
                    keys=[col("id")],
                    sequence_by=expr("ts"),
                )
        self.assertEqual(ctx.exception.getCondition(), "NOT_EXPECTED_TYPE")

    # A key element that is neither a string nor a Column must be rejected.
    def test_create_auto_cdc_flow_rejects_invalid_key_element(self):
        registry = LocalGraphElementRegistry()
        with graph_element_registration_context(registry):
            dp.create_streaming_table("tgt")
            with self.assertRaises(PySparkTypeError) as ctx:
                dp.create_auto_cdc_flow(
                    target="tgt",
                    source="src",
                    keys=[123],  # type: ignore[list-item]
                    sequence_by=expr("ts"),
                )
        self.assertEqual(ctx.exception.getCondition(), "NOT_EXPECTED_TYPE")

    # Calling the API with no registration context active must fail.
    def test_create_auto_cdc_flow_without_registry(self):
        with self.assertRaises(PySparkRuntimeError) as context:
            dp.create_auto_cdc_flow(
                target="t",
                source="s",
                keys=["k"],
                sequence_by="seq",
            )

        self.assertEqual(
            context.exception.getCondition(),
            "GRAPH_ELEMENT_DEFINED_OUTSIDE_OF_DECLARATIVE_PIPELINE",
        )
| |
| |
if __name__ == "__main__":
    # Script entry point: hand control to the test runner exposed by pyspark.testing.
    from pyspark.testing import main as run_tests

    run_tests()