[SPARK-49713][PYTHON][CONNECT] Make function `count_min_sketch` accept number arguments
### What changes were proposed in this pull request?
1. Make function `count_min_sketch` accept number arguments (see the usage sketch after the error output below);
2. Make argument `seed` optional;
3. Fix the type hints of `eps`/`confidence`/`seed` from `ColumnOrName` to `Column`, because these arguments require a foldable value and do not actually accept a column name:
```
In [3]: from pyspark.sql import functions as sf
In [4]: df = spark.range(10000).withColumn("seed", sf.lit(1).cast("int"))
In [5]: df.select(sf.hex(sf.count_min_sketch("id", sf.lit(0.5), sf.lit(0.5), "seed")))
...
AnalysisException: [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "count_min_sketch(id, 0.5, 0.5, seed)" due to data type mismatch: the input `seed` should be a foldable "INT" expression; however, got "seed". SQLSTATE: 42K09;
'Aggregate [unresolvedalias('hex(count_min_sketch(id#1L, 0.5, 0.5, seed#2, 0, 0)))]
+- Project [id#1L, cast(1 as int) AS seed#2]
+- Range (0, 10000, step=1, splits=Some(12))
...
```
### Why are the changes needed?
1. `seed` is optional in other similar functions (see the `rand` sketch below);
2. The existing type hint is `ColumnOrName`, which is misleading since a column name is not actually supported.
### Does this PR introduce _any_ user-facing change?
Yes, `count_min_sketch` now accepts number arguments, and `seed` is optional.
### How was this patch tested?
Updated doctests.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48157 from zhengruifeng/py_fix_count_min_sketch.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 2870d9c..7fed175 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -71,6 +71,7 @@
StringType,
)
from pyspark.sql.utils import enum_to_value as _enum_to_value
+from pyspark.util import JVM_INT_MAX
# The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf
# for code reuse.
@@ -1126,11 +1127,12 @@
def count_min_sketch(
col: "ColumnOrName",
- eps: "ColumnOrName",
- confidence: "ColumnOrName",
- seed: "ColumnOrName",
+ eps: Union[Column, float],
+ confidence: Union[Column, float],
+ seed: Optional[Union[Column, int]] = None,
) -> Column:
- return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
+ _seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed)
+ return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed)
count_min_sketch.__doc__ = pysparkfuncs.count_min_sketch.__doc__
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index c0730b1..5f8d1c2 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -6015,9 +6015,9 @@
@_try_remote_functions
def count_min_sketch(
col: "ColumnOrName",
- eps: "ColumnOrName",
- confidence: "ColumnOrName",
- seed: "ColumnOrName",
+ eps: Union[Column, float],
+ confidence: Union[Column, float],
+ seed: Optional[Union[Column, int]] = None,
) -> Column:
"""
Returns a count-min sketch of a column with the given eps, confidence and seed.
@@ -6031,13 +6031,25 @@
----------
col : :class:`~pyspark.sql.Column` or str
target column to compute on.
- eps : :class:`~pyspark.sql.Column` or str
+ eps : :class:`~pyspark.sql.Column` or float
relative error, must be positive
- confidence : :class:`~pyspark.sql.Column` or str
+
+ .. versionchanged:: 4.0.0
+ `eps` now accepts float value.
+
+ confidence : :class:`~pyspark.sql.Column` or float
confidence, must be positive and less than 1.0
- seed : :class:`~pyspark.sql.Column` or str
+
+ .. versionchanged:: 4.0.0
+ `confidence` now accepts float value.
+
+ seed : :class:`~pyspark.sql.Column` or int, optional
random seed
+
+ .. versionchanged:: 4.0.0
+ `seed` now accepts int value.
+
Returns
-------
:class:`~pyspark.sql.Column`
@@ -6045,12 +6056,48 @@
Examples
--------
- >>> df = spark.createDataFrame([[1], [2], [1]], ['data'])
- >>> df = df.agg(count_min_sketch(df.data, lit(0.5), lit(0.5), lit(1)).alias('sketch'))
- >>> df.select(hex(df.sketch).alias('r')).collect()
- [Row(r='0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000')]
- """
- return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
+ Example 1: Using columns as arguments
+
+ >>> from pyspark.sql import functions as sf
+ >>> spark.range(100).select(
+ ... sf.hex(sf.count_min_sketch(sf.col("id"), sf.lit(3.0), sf.lit(0.1), sf.lit(1)))
+ ... ).show(truncate=False)
+ +------------------------------------------------------------------------+
+ |hex(count_min_sketch(id, 3.0, 0.1, 1)) |
+ +------------------------------------------------------------------------+
+ |0000000100000000000000640000000100000001000000005D8D6AB90000000000000064|
+ +------------------------------------------------------------------------+
+
+ Example 2: Using numbers as arguments
+
+ >>> from pyspark.sql import functions as sf
+ >>> spark.range(100).select(
+ ... sf.hex(sf.count_min_sketch("id", 1.0, 0.3, 2))
+ ... ).show(truncate=False)
+ +----------------------------------------------------------------------------------------+
+ |hex(count_min_sketch(id, 1.0, 0.3, 2)) |
+ +----------------------------------------------------------------------------------------+
+ |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032|
+ +----------------------------------------------------------------------------------------+
+
+ Example 3: Using a random seed
+
+ >>> from pyspark.sql import functions as sf
+ >>> spark.range(100).select(
+ ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6))
+ ... ).show(truncate=False) # doctest: +SKIP
+ +----------------------------------------------------------------------------------------------------------------------------------------+
+ |hex(count_min_sketch(id, 1.5, 0.6, 2120704260)) |
+ +----------------------------------------------------------------------------------------------------------------------------------------+
+ |0000000100000000000000640000000200000002000000005ADECCEE00000000153EBE090000000000000033000000000000003100000000000000320000000000000032|
+ +----------------------------------------------------------------------------------------------------------------------------------------+
+ """ # noqa: E501
+ _eps = lit(eps)
+ _conf = lit(confidence)
+ if seed is None:
+ return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf)
+ else:
+ return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf, lit(seed))
@_try_remote_functions
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 0266927..0662b8f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -389,6 +389,18 @@
def count_min_sketch(e: Column, eps: Column, confidence: Column, seed: Column): Column =
Column.fn("count_min_sketch", e, eps, confidence, seed)
+ /**
+ * Returns a count-min sketch of a column with the given eps and confidence, using a random seed. The result is
+ * an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min
+ * sketch is a probabilistic data structure used for cardinality estimation using sub-linear
+ * space.
+ *
+ * @group agg_funcs
+ * @since 4.0.0
+ */
+ def count_min_sketch(e: Column, eps: Column, confidence: Column): Column =
+ count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt))
+
private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
Column.internalFn("collect_top_k", e, lit(num), lit(reverse))