blob: f18add2201db8bda5b4925f69043f7c7ef4c249e [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Optional
import os
import time
# Optional third-party dependencies. Each one is imported defensively so this
# module remains importable in environments where the package is missing; the
# corresponding have_* flag gates every code path that touches the package
# (see repr_value()).
try:
    import numpy as np

    have_numpy = True
except ImportError:
    have_numpy = False

try:
    import pandas as pd

    have_pandas = True
except ImportError:
    have_pandas = False

try:
    import pyarrow as pa

    have_pyarrow = True
except ImportError:
    have_pyarrow = False
# str() of a PyArrow floating-point DataType yields Arrow's internal names
# ("halffloat", "float", "double"). repr_type() uses this mapping to present
# the conventional "float16"/"float32"/"float64" spellings instead.
_ARROW_FLOAT_ALIASES = dict(
    halffloat="float16",
    float="float32",
    double="float64",
)
class GoldenFileTestMixin:
    """
    Mixin class providing utilities for golden file based testing.

    Golden files are CSV (tab-separated) files that store expected test
    results. This mixin provides:
    - Timezone setup/teardown for deterministic results
    - Golden file read/write with SPARK_GENERATE_GOLDEN_FILES env var support
    - Result string cleaning utilities

    To regenerate golden files, set SPARK_GENERATE_GOLDEN_FILES=1 before
    running tests.

    Usage:
        class MyTest(GoldenFileTestMixin, ReusedSQLTestCase):
            def test_something(self):
                # Use helper methods from mixin
                if self.is_generating_golden():
                    self.save_golden(df, golden_csv, golden_md)
                else:
                    golden = self.load_golden_csv(golden_csv)
                    # compare results with golden
    """

    # Previous value of the TZ environment variable, captured by
    # setup_timezone() so that teardown_timezone() can restore it.
    _tz_prev: Optional[str] = None
def __init_subclass__(cls, **kwargs):
"""Verify correct inheritance order at class definition time."""
super().__init_subclass__(**kwargs)
# Check that GoldenFileTestMixin comes before any class with setUpClass in MRO.
# This ensures setup_timezone() will be called after Spark session is created.
# Correct: class MyTest(GoldenFileTestMixin, ReusedSQLTestCase)
# Incorrect: class MyTest(ReusedSQLTestCase, GoldenFileTestMixin)
for base in cls.__mro__:
if base is GoldenFileTestMixin:
break
# If we find a class with setUpClass before GoldenFileTestMixin, that's wrong
if base is not cls and hasattr(base, "setUpClass") and "setUpClass" in base.__dict__:
raise TypeError(
f"{cls.__name__} has incorrect inheritance order. "
f"GoldenFileTestMixin must be listed BEFORE {base.__name__}. "
f"Use: class {cls.__name__}(GoldenFileTestMixin, {base.__name__}, ...)"
)
    @classmethod
    def setUpClass(cls) -> None:
        """Create base-class fixtures, then pin the timezone.

        setup_timezone() runs AFTER super().setUpClass() so that a Spark
        session created by a base class exists and can be synchronised.
        """
        super().setUpClass()
        cls.setup_timezone()
    @classmethod
    def tearDownClass(cls) -> None:
        """Restore the original timezone, then tear down base-class fixtures.

        Mirror image of setUpClass(): the timezone is reset BEFORE
        super().tearDownClass() so it happens while the session still exists.
        """
        cls.teardown_timezone()
        super().tearDownClass()
@classmethod
def setup_timezone(cls, tz: str = "America/Los_Angeles") -> None:
"""
Setup timezone for deterministic test results.
Sets the OS-level TZ environment variable and, when a Spark session
is available, synchronises the timezone with the JVM and Spark config.
This allows the mixin to be used with both ReusedSQLTestCase (Spark)
and plain unittest.TestCase (no Spark).
"""
cls._tz_prev = os.environ.get("TZ", None)
os.environ["TZ"] = tz
time.tzset()
# Sync with Spark / Java if a session is available.
if hasattr(cls, "sc"):
cls.sc.environment["TZ"] = tz
if hasattr(cls, "spark"):
cls.spark.conf.set("spark.sql.session.timeZone", tz)
@classmethod
def teardown_timezone(cls) -> None:
"""Restore original timezone."""
if "TZ" in os.environ:
del os.environ["TZ"]
if cls._tz_prev is not None:
os.environ["TZ"] = cls._tz_prev
time.tzset()
@staticmethod
def is_generating_golden() -> bool:
"""Check if we are generating golden files (vs testing against them)."""
return os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "0") == "1"
@staticmethod
def load_golden_csv(golden_csv: str, use_index: bool = True) -> "pd.DataFrame":
"""
Load golden file from CSV.
Parameters
----------
golden_csv : str
Path to the golden CSV file.
use_index : bool
If True, use first column as index.
If False, don't use index.
Returns
-------
pd.DataFrame
The loaded golden data with string dtype.
"""
return pd.read_csv(
golden_csv,
sep="\t",
index_col=0 if use_index else None,
dtype="str",
na_filter=False,
engine="python",
)
@staticmethod
def save_golden(df: "pd.DataFrame", golden_csv: str, golden_md: Optional[str] = None) -> None:
"""
Save DataFrame as golden file (CSV and optionally Markdown).
Parameters
----------
df : pd.DataFrame
The DataFrame to save.
golden_csv : str
Path to save the CSV file.
golden_md : str, optional
Path to save the Markdown file. Requires tabulate package.
"""
df.to_csv(golden_csv, sep="\t", header=True, index=True)
if golden_md is not None:
try:
df.to_markdown(golden_md, index=True, tablefmt="github")
except Exception as e:
import warnings
warnings.warn(
f"Failed to write markdown file {golden_md}: {e}. "
"Install 'tabulate' package to generate markdown files."
)
@staticmethod
def repr_type(t: Any) -> str:
"""
Convert a type to a readable string representation.
Handles different type representations:
- Spark DataType: uses simpleString()
(e.g. "int", "string", "array<int>")
- PyArrow DataType: uses str(t) with float-name normalisation
(e.g. "int8", "float32", "timestamp[s, tz=UTC]")
- Python type: uses __name__
(e.g. "int", "str", "list")
- Other: falls back to str(t)
Parameters
----------
t : Any
The type to represent.
Returns
-------
str
Human-readable string representation of the type.
"""
# Spark DataType
if hasattr(t, "simpleString"):
return t.simpleString()
# Python type (class)
elif isinstance(t, type):
return t.__name__
else:
s = str(t)
# Normalise PyArrow float type names to be more intuitive:
# "halffloat" -> "float16", "float" -> "float32", "double" -> "float64"
return _ARROW_FLOAT_ALIASES.get(s, s)
@classmethod
def repr_arrow_value(cls, value: Any, max_len: int = 32) -> str:
"""
Format a PyArrow Array/ChunkedArray for golden file.
Each element uses str(scalar) from PyArrow's own scalar formatting.
Parameters
----------
value : pa.Array or pa.ChunkedArray
The PyArrow array to represent.
max_len : int, default 32
Maximum length for the value string portion. 0 means no limit.
Returns
-------
str
"[val1, val2, None]@arrow_type"
"""
# Escape NULL bytes so the value can be safely stored in CSV files.
elements = [str(scalar).replace("\x00", "\\0") for scalar in value]
v_str = "[" + ", ".join(elements) + "]"
if max_len > 0:
v_str = v_str[:max_len]
return f"{v_str}@{cls.repr_type(value.type)}"
@classmethod
def repr_pandas_value(cls, value: Any, max_len: int = 32) -> str:
"""
Format a pandas DataFrame for golden file.
Parameters
----------
value : pd.DataFrame
The pandas DataFrame to represent.
max_len : int, default 32
Maximum length for the value string portion. 0 means no limit.
Returns
-------
str
"value@Dataframe[schema]"
"""
v_str = value.to_json().replace("\n", " ")
if max_len > 0:
v_str = v_str[:max_len]
simple_schema = ", ".join([f"{t} {d.name}" for t, d in value.dtypes.items()])
return f"{v_str}@Dataframe[{simple_schema}]"
@classmethod
def repr_numpy_value(cls, value: Any, max_len: int = 32) -> str:
"""
Format a numpy ndarray for golden file.
Parameters
----------
value : np.ndarray
The numpy ndarray to represent.
max_len : int, default 32
Maximum length for the value string portion. 0 means no limit.
Returns
-------
str
"value@ndarray[dtype]"
"""
v_str = str(value).replace("\n", " ")
if max_len > 0:
v_str = v_str[:max_len]
return f"{v_str}@ndarray[{value.dtype.name}]"
@classmethod
def repr_python_value(cls, value: Any, max_len: int = 32) -> str:
"""
Format a plain Python value for golden file.
Returns
-------
str
"str(value)@class_name"
"""
v_str = str(value).replace("\n", " ")
if max_len > 0:
v_str = v_str[:max_len]
return f"{v_str}@{type(value).__name__}"
@classmethod
def repr_value(cls, value: Any, max_len: int = 32) -> str:
"""
Format a value for golden file, dispatching to the appropriate repr
based on the value's type.
- PyArrow Array/ChunkedArray -> repr_arrow_value
- pandas DataFrame -> repr_pandas_value
- numpy ndarray -> repr_numpy_value
- Everything else -> repr_python_value
Parameters
----------
value : Any
The value to represent.
max_len : int, default 32
Maximum length for the value string portion. 0 means no limit.
Returns
-------
str
String representation in format "value@type".
"""
if have_pyarrow and isinstance(value, (pa.Array, pa.ChunkedArray)):
return cls.repr_arrow_value(value, max_len)
if have_pandas and isinstance(value, pd.DataFrame):
return cls.repr_pandas_value(value, max_len)
if have_numpy and isinstance(value, np.ndarray):
return cls.repr_numpy_value(value, max_len)
return cls.repr_python_value(value, max_len)
@staticmethod
def clean_result(result: str) -> str:
"""Clean result string by removing newlines and extra whitespace."""
return result.replace("\n", " ").replace("\r", " ").replace("\t", " ")