#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import tempfile
import textwrap
from pathlib import Path

from pyspark.errors import PySparkException
from pyspark.testing.connectutils import (
    should_test_connect,
    connect_requirement_message,
)
from pyspark.testing.utils import have_yaml, yaml_requirement_message

if should_test_connect and have_yaml:
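    # These imports depend on Spark Connect and PyYAML; guard them so this module can
    # still be imported (and the tests skipped) when either is unavailable.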
    from pyspark.pipelines.cli import (
        change_dir,
        find_pipeline_spec,
        load_pipeline_spec,
        register_definitions,
        unpack_pipeline_spec,
        LibrariesGlob,
        PipelineSpec,
        run,
    )
    from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry


@unittest.skipIf(
    not should_test_connect or not have_yaml,
    connect_requirement_message or yaml_requirement_message,
)
class CLIUtilityTests(unittest.TestCase):
    def test_load_pipeline_spec(self):
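        """Loads a spec file with all fields populated and verifies each parsed attribute."""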
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
            tmpfile.write(
                """
                {
                    "name": "test_pipeline",
                    "catalog": "test_catalog",
                    "database": "test_database",
                    "configuration": {
                        "key1": "value1",
                        "key2": "value2"
                    },
                    "libraries": [
                        {"glob": {"include": "test_include"}}
                    ]
                }
                """
            )
            tmpfile.flush()
            spec = load_pipeline_spec(Path(tmpfile.name))
            assert spec.name == "test_pipeline"
            assert spec.catalog == "test_catalog"
            assert spec.database == "test_database"
            assert spec.configuration == {"key1": "value1", "key2": "value2"}
            assert len(spec.libraries) == 1
            assert spec.libraries[0].include == "test_include"

    def test_load_pipeline_spec_name_is_required(self):
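        """Omitting the required name field should raise PIPELINE_SPEC_MISSING_REQUIRED_FIELD."""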
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
            tmpfile.write(
                """
                {
                    "catalog": "test_catalog",
                    "database": "test_database",
                    "configuration": {
                        "key1": "value1",
                        "key2": "value2"
                    },
                    "libraries": [
                        {"glob": {"include": "test_include"}}
                    ]
                }
                """
            )
            tmpfile.flush()
            with self.assertRaises(PySparkException) as context:
                load_pipeline_spec(Path(tmpfile.name))
            self.assertEqual(
                context.exception.getCondition(), "PIPELINE_SPEC_MISSING_REQUIRED_FIELD"
            )
            self.assertEqual(context.exception.getMessageParameters(), {"field_name": "name"})

    def test_load_pipeline_spec_schema_fallback(self):
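        """A schema field in the spec should be accepted as a fallback for database."""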
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
            tmpfile.write(
                """
                {
                    "name": "test_pipeline",
                    "catalog": "test_catalog",
                    "schema": "test_database",
                    "configuration": {
                        "key1": "value1",
                        "key2": "value2"
                    },
                    "libraries": [
                        {"glob": {"include": "test_include"}}
                    ]
                }
                """
            )
            tmpfile.flush()
            spec = load_pipeline_spec(Path(tmpfile.name))
            assert spec.catalog == "test_catalog"
            assert spec.database == "test_database"
            assert spec.configuration == {"key1": "value1", "key2": "value2"}
            assert len(spec.libraries) == 1
            assert spec.libraries[0].include == "test_include"

    def test_load_pipeline_spec_invalid(self):
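        """An unrecognized top-level field should raise PIPELINE_SPEC_UNEXPECTED_FIELD."""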
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
            tmpfile.write(
                """
                {
                    "catalogtypo": "test_catalog",
                    "configuration": {
                        "key1": "value1",
                        "key2": "value2"
                    },
                    "libraries": [
                        {"glob": {"include": "test_include"}}
                    ]
                }
                """
            )
            tmpfile.flush()
            with self.assertRaises(PySparkException) as context:
                load_pipeline_spec(Path(tmpfile.name))
            self.assertEqual(context.exception.getCondition(), "PIPELINE_SPEC_UNEXPECTED_FIELD")
            self.assertEqual(
                context.exception.getMessageParameters(), {"field_name": "catalogtypo"}
            )

    def test_unpack_empty_pipeline_spec(self):
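        """A spec containing only a name should unpack with empty defaults for the rest."""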
        empty_spec = PipelineSpec(
            name="test_pipeline", catalog=None, database=None, configuration={}, libraries=[]
        )
        self.assertEqual(unpack_pipeline_spec({"name": "test_pipeline"}), empty_spec)

    def test_unpack_pipeline_spec_bad_configuration(self):
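        """configuration must be a dict of strings; anything else raises TypeError."""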
        with self.assertRaises(TypeError) as context:
            unpack_pipeline_spec({"name": "test_pipeline", "configuration": "not_a_dict"})
        self.assertIn("should be a dict", str(context.exception))

        with self.assertRaises(TypeError) as context:
            unpack_pipeline_spec({"name": "test_pipeline", "configuration": {"key": {}}})
        self.assertIn("key", str(context.exception))

        with self.assertRaises(TypeError) as context:
            unpack_pipeline_spec({"name": "test_pipeline", "configuration": {1: "something"}})
        self.assertIn("int", str(context.exception))

    def test_find_pipeline_spec_in_current_directory(self):
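        """A pipeline.yaml in the search directory itself should be found."""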
        with tempfile.TemporaryDirectory() as temp_dir:
            spec_path = Path(temp_dir) / "pipeline.yaml"
            with spec_path.open("w") as f:
                f.write(
                    """
                    {
                        "catalog": "test_catalog",
                        "configuration": {},
                        "libraries": []
                    }
                    """
                )
            found_spec = find_pipeline_spec(Path(temp_dir))
            self.assertEqual(found_spec, spec_path)

    def test_find_pipeline_spec_in_current_directory_yml(self):
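        """A spec file with the .yml extension should be found as well as .yaml."""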
        with tempfile.TemporaryDirectory() as temp_dir:
            spec_path = Path(temp_dir) / "pipeline.yml"
            with spec_path.open("w") as f:
                f.write(
                    """
                    {
                        "catalog": "test_catalog",
                        "configuration": {},
                        "libraries": []
                    }
                    """
                )
            found_spec = find_pipeline_spec(Path(temp_dir))
            self.assertEqual(found_spec, spec_path)

    def test_find_pipeline_spec_in_current_directory_yml_and_yaml(self):
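        """Both pipeline.yml and pipeline.yaml in one directory is ambiguous and should fail."""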
        with tempfile.TemporaryDirectory() as temp_dir:
            with (Path(temp_dir) / "pipeline.yml").open("w") as f:
                f.write("")
            with (Path(temp_dir) / "pipeline.yaml").open("w") as f:
                f.write("")
            with self.assertRaises(PySparkException) as context:
                find_pipeline_spec(Path(temp_dir))
            self.assertEqual(context.exception.getCondition(), "MULTIPLE_PIPELINE_SPEC_FILES_FOUND")
            self.assertEqual(context.exception.getMessageParameters(), {"dir_path": temp_dir})

    def test_find_pipeline_spec_in_parent_directory(self):
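        """The spec search should walk up from a child directory to a spec in the parent."""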
        with tempfile.TemporaryDirectory() as temp_dir:
            parent_dir = Path(temp_dir)
            child_dir = parent_dir / "child"
            child_dir.mkdir()
            spec_path = parent_dir / "pipeline.yaml"
            with spec_path.open("w") as f:
                f.write(
                    """
                    {
                        "catalog": "test_catalog",
                        "configuration": {},
                        "libraries": []
                    }
                    """
                )
            found_spec = find_pipeline_spec(child_dir)
            self.assertEqual(found_spec, spec_path)

    def test_register_definitions(self):
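        """Only files matching the spec's include globs should be evaluated for definitions."""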
        spec = PipelineSpec(
            name="test_pipeline",
            catalog=None,
            database=None,
            configuration={},
            libraries=[LibrariesGlob(include="subdir1/**")],
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            outer_dir = Path(temp_dir)
            subdir1 = outer_dir / "subdir1"
            subdir1.mkdir()
            subdir2 = outer_dir / "subdir2"
            subdir2.mkdir()
            with (subdir1 / "libraries.py").open("w") as f:
                f.write(
                    textwrap.dedent(
                        """
                        from pyspark import pipelines as dp

                        @dp.materialized_view
                        def mv1():
                            raise NotImplementedError()
                        """
                    )
                )
            with (subdir2 / "libraries.py").open("w") as f:
                f.write(
                    textwrap.dedent(
                        """
                        from pyspark import pipelines as dp

                        def mv2():
                            raise NotImplementedError()
                        """
                    )
                )
            registry = LocalGraphElementRegistry()
            register_definitions(outer_dir / "pipeline.yaml", registry, spec)
            self.assertEqual(len(registry.datasets), 1)
            self.assertEqual(registry.datasets[0].name, "mv1")

    def test_register_definitions_file_raises_error(self):
        """Errors raised while executing definitions code should propagate to the caller."""
        spec = PipelineSpec(
            name="test_pipeline",
            catalog=None,
            database=None,
            configuration={},
            libraries=[LibrariesGlob(include="./**")],
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            outer_dir = Path(temp_dir)
            with (outer_dir / "definitions.py").open("w") as f:
                f.write("raise RuntimeError('This is a test exception')")
            registry = LocalGraphElementRegistry()
            with self.assertRaises(RuntimeError) as context:
                register_definitions(outer_dir / "pipeline.yml", registry, spec)
            self.assertIn("This is a test exception", str(context.exception))

    def test_register_definitions_unsupported_file_extension_matches_glob(self):
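        """A matched file with an unsupported extension (.java here) should be rejected."""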
        spec = PipelineSpec(
            name="test_pipeline",
            catalog=None,
            database=None,
            configuration={},
            libraries=[LibrariesGlob(include="./**")],
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            outer_dir = Path(temp_dir)
            with (outer_dir / "definitions.java").open("w") as f:
                f.write("")
            registry = LocalGraphElementRegistry()
            with self.assertRaises(PySparkException) as context:
                register_definitions(outer_dir, registry, spec)
            self.assertEqual(
                context.exception.getCondition(), "PIPELINE_UNSUPPORTED_DEFINITIONS_FILE_EXTENSION"
            )

    def test_python_import_current_directory(self):
        """Tests that the Python system path is resolved relative to the dir containing the
        pipeline spec file."""
        with tempfile.TemporaryDirectory() as temp_dir:
            outer_dir = Path(temp_dir)
            inner_dir1 = outer_dir / "inner1"
            inner_dir1.mkdir()
            inner_dir2 = outer_dir / "inner2"
            inner_dir2.mkdir()
            with (inner_dir1 / "defs.py").open("w") as f:
                f.write(
                    textwrap.dedent(
                        """
                        import sys

                        sys.path.append(".")
                        import mypackage.my_module
                        """
                    )
                )
            inner_dir1_mypackage = inner_dir1 / "mypackage"
            inner_dir1_mypackage.mkdir()
            with (inner_dir1_mypackage / "__init__.py").open("w") as f:
                f.write("")
            with (inner_dir1_mypackage / "my_module.py").open("w") as f:
                f.write("")
            registry = LocalGraphElementRegistry()
            with change_dir(inner_dir2):
                register_definitions(
                    inner_dir1 / "pipeline.yaml",
                    registry,
                    PipelineSpec(
                        name="test_pipeline",
                        catalog=None,
                        database=None,
                        configuration={},
                        libraries=[LibrariesGlob(include="defs.py")],
                    ),
                )

    def test_full_refresh_all_conflicts_with_full_refresh(self):
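        """Providing both --full-refresh-all and --full-refresh should raise a conflict error."""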
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create a minimal pipeline spec
            spec_path = Path(temp_dir) / "pipeline.yaml"
            with spec_path.open("w") as f:
                f.write('{"name": "test_pipeline"}')

            # Test that providing both --full-refresh-all and --full-refresh raises an exception
            with self.assertRaises(PySparkException) as context:
                run(
                    spec_path=spec_path,
                    full_refresh=["table1", "table2"],
                    full_refresh_all=True,
                    refresh=[],
                    dry=False,
                )
            self.assertEqual(
                context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS"
            )
            self.assertEqual(
                context.exception.getMessageParameters(), {"conflicting_option": "--full_refresh"}
            )

    def test_full_refresh_all_conflicts_with_refresh(self):
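        """Providing both --full-refresh-all and --refresh should raise a conflict error."""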
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create a minimal pipeline spec
            spec_path = Path(temp_dir) / "pipeline.yaml"
            with spec_path.open("w") as f:
                f.write('{"name": "test_pipeline"}')

            # Test that providing both --full-refresh-all and --refresh raises an exception
            with self.assertRaises(PySparkException) as context:
                run(
                    spec_path=spec_path,
                    full_refresh=[],
                    full_refresh_all=True,
                    refresh=["table1", "table2"],
                    dry=False,
                )
            self.assertEqual(
                context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS"
            )
            self.assertEqual(
                context.exception.getMessageParameters(),
                {"conflicting_option": "--refresh"},
            )

    def test_full_refresh_all_conflicts_with_both(self):
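        """--full-refresh-all combined with both other options should still raise a conflict."""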
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create a minimal pipeline spec
            spec_path = Path(temp_dir) / "pipeline.yaml"
            with spec_path.open("w") as f:
                f.write('{"name": "test_pipeline"}')

            # Test that providing --full-refresh-all with both other options raises an exception
            # (it should catch the first conflict - full_refresh)
            with self.assertRaises(PySparkException) as context:
                run(
                    spec_path=spec_path,
                    full_refresh=["table1"],
                    full_refresh_all=True,
                    refresh=["table2"],
                    dry=False,
                )
            self.assertEqual(
                context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS"
            )

    def test_parse_table_list_single_table(self):
        """Test parsing a single table name."""
        from pyspark.pipelines.cli import parse_table_list

        result = parse_table_list("table1")
        self.assertEqual(result, ["table1"])

    def test_parse_table_list_multiple_tables(self):
        """Test parsing multiple table names."""
        from pyspark.pipelines.cli import parse_table_list

        result = parse_table_list("table1,table2,table3")
        self.assertEqual(result, ["table1", "table2", "table3"])

    def test_parse_table_list_with_spaces(self):
        """Test parsing table names with surrounding spaces."""
        from pyspark.pipelines.cli import parse_table_list

        result = parse_table_list("table1, table2 , table3")
        self.assertEqual(result, ["table1", "table2", "table3"])

    def test_valid_glob_patterns(self):
        """Test that valid glob patterns are accepted."""
        from pyspark.pipelines.cli import validate_patch_glob_pattern

        cases = {
            # Simple file paths
            "src/main.py": "src/main.py",
            "data/file.sql": "data/file.sql",
            # Folder paths ending with /** are normalized to /**/*
            "src/**": "src/**/*",
            "transformations/**": "transformations/**/*",
            "notebooks/production/**": "notebooks/production/**/*",
        }
        for pattern, expected in cases.items():
            with self.subTest(pattern=pattern):
                self.assertEqual(validate_patch_glob_pattern(pattern), expected)

    def test_invalid_glob_patterns(self):
        """Test that invalid glob patterns are rejected."""
        from pyspark.pipelines.cli import validate_patch_glob_pattern

        invalid_patterns = [
            "transformations/**/*.py",
            "src/**/utils/*.py",
            "*/main.py",
            "src/*/test/*.py",
            "**/*.py",
            "data/*/file.sql",
        ]
        for pattern in invalid_patterns:
            with self.subTest(pattern=pattern):
                with self.assertRaises(PySparkException) as context:
                    validate_patch_glob_pattern(pattern)
                self.assertEqual(
                    context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN"
                )
                self.assertEqual(
                    context.exception.getMessageParameters(), {"glob_pattern": pattern}
                )

    def test_pipeline_spec_with_invalid_glob_pattern(self):
        """Test that a pipeline spec with an invalid glob pattern is rejected."""
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
            tmpfile.write(
                """
                {
                    "name": "test_pipeline",
                    "libraries": [
                        {"glob": {"include": "transformations/**/*.py"}}
                    ]
                }
                """
            )
            tmpfile.flush()
            with self.assertRaises(PySparkException) as context:
                load_pipeline_spec(Path(tmpfile.name))
            self.assertEqual(context.exception.getCondition(), "PIPELINE_SPEC_INVALID_GLOB_PATTERN")
            self.assertEqual(
                context.exception.getMessageParameters(),
                {"glob_pattern": "transformations/**/*.py"},
            )


if __name__ == "__main__":
    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)