[MINOR][PYTHON][TESTS] Skip some tests if numpy not installed
### What changes were proposed in this pull request?
Skip some tests when NumPy is not installed.
### Why are the changes needed?
These tests depend on NumPy, so they should be skipped when it is not available.
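For reference, the guard applied throughout this patch is the standard `unittest.skipIf` pattern. A minimal sketch is below; the class and test names are hypothetical, while `have_numpy`, `numpy_requirement_message`, `have_pyarrow`, and `pyarrow_requirement_message` are the real helpers imported from `pyspark.testing.utils` in the patch:

```python
import unittest

from pyspark.testing.utils import (
    have_numpy,
    numpy_requirement_message,
    have_pyarrow,
    pyarrow_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase


# Class-level guard: skip the whole suite when PyArrow is absent
# (this guard already existed before the patch).
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ArrowUDFExampleTests(ReusedSQLTestCase):
    # Per-test guard: skip only the tests that actually import NumPy
    # (this is what the patch adds).
    @unittest.skipIf(not have_numpy, numpy_requirement_message)
    def test_uses_numpy(self):
        import numpy as np  # safe: the decorator skips this test if NumPy is missing

        self.assertEqual(np.int64(1) + np.int64(1), 2)
```

The class-level PyArrow guard was already in place; the patch only adds the per-test NumPy guard to the handful of tests that import NumPy inside their bodies.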
### Does this PR introduce _any_ user-facing change?
No, this is test-only.
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52300 from zhengruifeng/test_skip_numpy.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 81a9c81..d49f341 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -30,11 +30,13 @@
)
from pyspark.sql import functions as sf
from pyspark.errors import AnalysisException, PythonException
-from pyspark.testing.sqlutils import (
- ReusedSQLTestCase,
+from pyspark.testing.utils import (
+ have_numpy,
+ numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -146,6 +148,7 @@
self.assertEqual(expected, result.collect())
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_basic(self):
df = self.data
weighted_mean_udf = self.arrow_agg_weighted_mean_udf
@@ -268,6 +271,7 @@
self.assertEqual(expected5, result5.collect())
self.assertEqual(expected6, result6.collect())
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_multiple_udfs(self):
"""
Test multiple group aggregate pandas UDFs in one agg function.
@@ -537,6 +541,7 @@
assert filtered.collect()[0]["mean"] == 42.0
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_named_arguments(self):
df = self.data
weighted_mean = self.arrow_agg_weighted_mean_udf
@@ -565,6 +570,7 @@
df.groupby("id").agg(sf.mean(df.v).alias("wm")).collect(),
)
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_named_arguments_negative(self):
df = self.data
weighted_mean = self.arrow_agg_weighted_mean_udf
@@ -886,6 +892,7 @@
# Integer value 2147483657 not in range: -2147483648 to 2147483647
result3.collect()
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_return_numpy_scalar(self):
import numpy as np
import pyarrow as pa
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index d6e010d..3409ce9 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -46,11 +46,13 @@
YearMonthIntervalType,
)
from pyspark.errors import AnalysisException, PythonException
-from pyspark.testing.sqlutils import (
- ReusedSQLTestCase,
+from pyspark.testing.utils import (
+ have_numpy,
+ numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -813,6 +815,7 @@
[row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
self.assertEqual(row[0], 7)
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_nondeterministic_arrow_udf(self):
import pyarrow as pa
@@ -835,6 +838,7 @@
self.assertEqual(random_udf.deterministic, False)
self.assertTrue(result1["plus_ten(rand)"].equals(result1["rand"] + 10))
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_nondeterministic_arrow_udf_in_aggregate(self):
with self.quiet():
df = self.spark.range(10)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index b543c56..1d30159 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -22,11 +22,13 @@
from pyspark.sql import functions as sf
from pyspark.sql.window import Window
from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
-from pyspark.testing.sqlutils import (
- ReusedSQLTestCase,
+from pyspark.testing.utils import (
+ have_numpy,
+ numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -384,6 +386,7 @@
self.assertEqual(expected1.collect(), result1.collect())
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_named_arguments(self):
df = self.data
weighted_mean = self.arrow_agg_weighted_mean_udf
@@ -427,6 +430,7 @@
).collect(),
)
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_named_arguments_negative(self):
df = self.data
weighted_mean = self.arrow_agg_weighted_mean_udf
@@ -718,6 +722,7 @@
# Integer value 2147483657 not in range: -2147483648 to 2147483647
result3.collect()
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_return_numpy_scalar(self):
import numpy as np
import pyarrow as pa