[SPARK-48184][PYTHON][CONNECT] Always set the seed of `Dataframe.sample` in Client side
### What changes were proposed in this pull request?
Always set the seed of `Dataframe.sample` in Client side
### Why are the changes needed?
Bug fix
If the seed is not set in Client, it will be set in server side with a random int
https://github.com/apache/spark/blob/c4df12cc884cddefcfcf8324b4d7b9349fb4f6a0/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L386
which cause inconsistent results in multiple executions
In Spark Classic:
```
In [1]: df = spark.range(10000).sample(0.1)
In [2]: [df.count() for i in range(10)]
Out[2]: [1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006]
```
In Spark Connect:
before:
```
In [1]: df = spark.range(10000).sample(0.1)
In [2]: [df.count() for i in range(10)]
Out[2]: [969, 1005, 958, 996, 987, 1026, 991, 1020, 1012, 979]
```
after:
```
In [1]: df = spark.range(10000).sample(0.1)
In [2]: [df.count() for i in range(10)]
Out[2]: [1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032]
```
### Does this PR introduce _any_ user-facing change?
yes, bug fix
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #46456 from zhengruifeng/py_connect_sample_seed.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f9a209d..843c92a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -813,7 +813,7 @@
if withReplacement is None:
withReplacement = False
- seed = int(seed) if seed is not None else None
+ seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
return DataFrame(
plan.Sample(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 09c3171..e8d04ae 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -443,7 +443,7 @@
self.assertEqual(plan.root.sample.lower_bound, 0.0)
self.assertEqual(plan.root.sample.upper_bound, 0.3)
self.assertEqual(plan.root.sample.with_replacement, False)
- self.assertEqual(plan.root.sample.HasField("seed"), False)
+ self.assertEqual(plan.root.sample.HasField("seed"), True)
self.assertEqual(plan.root.sample.deterministic_order, False)
plan = (
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 16dd0d2..f491b49 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -430,6 +430,11 @@
IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count()
)
+ def test_sample_with_random_seed(self):
+ df = self.spark.range(10000).sample(0.1)
+ cnts = [df.count() for i in range(10)]
+ self.assertEqual(1, len(set(cnts)))
+
def test_toDF_with_string(self):
df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)])
data = [("John", 30), ("Alice", 25), ("Bob", 28)]