[SPARK-53779][SQL][CONNECT] Implement `transform()` in Column API
### What changes were proposed in this pull request?
Add a `transform()` method to the Column API, analogous to `Dataset.transform()`:
```scala
def transform(f: Column => Column): Column
```
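For reference, this mirrors the shape of the existing `Dataset.transform` signature:
```scala
// Existing Dataset API, shown for comparison:
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U]
```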
### Why are the changes needed?
This gives users a concise way to chain custom column transformations, for example:
```scala
df.select($"fruit"
.transform(addPrefix)
.transform(uppercase)
)
```
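For illustration, `addPrefix` and `uppercase` above can be any plain `Column => Column` helpers; a hypothetical sketch:
```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{concat, lit, upper}

// Hypothetical helpers for the example above; any Column => Column works.
def addPrefix(c: Column): Column = concat(lit("prefix_"), c)
def uppercase(c: Column): Column = upper(c)
```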
This pattern is also easier for AI agents to learn and write.
### Does this PR introduce _any_ user-facing change?
Yes, a new `Column.transform()` API is introduced.
### How was this patch tested?
Unit tests added to `ColumnTestSuite` and `ColumnExpressionSuite`.
### Was this patch authored or co-authored using generative AI tooling?
Yes, the tests were generated with Copilot.
Closes #52537 from Yicong-Huang/feat/transform-in-column-api.
Lead-authored-by: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Co-authored-by: Yicong Huang <17627829+Yicong-Huang@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
index 316b629..56a9787 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1424,6 +1424,22 @@
*/
def outer(): Column = Column(internal.LazyExpression(node))
+ /**
+ * Concise syntax for chaining custom transformations.
+ * {{{
+ * def addPrefix(c: Column): Column = concat(lit("prefix_"), c)
+ *
+ * df.select($"name".transform(addPrefix))
+ *
+ * // Chaining multiple transformations
+ * df.select($"name".transform(addPrefix).transform(upper))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 4.1.0
+ */
+ def transform(f: Column => Column): Column = f(this)
+
}
/**
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala
index 863cb58..698057f 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnTestSuite.scala
@@ -209,4 +209,58 @@
testColName(structType1, _.struct(structType1))
import org.apache.spark.util.ArrayImplicits._
testColName(structType2, _.struct(structType2.fields.toImmutableArraySeq: _*))
+
+ test("transform with named function") {
+ val a = fn.col("a")
+ def addOne(c: Column): Column = c + 1
+ val transformed = a.transform(addOne)
+ assert(transformed == (a + 1))
+ }
+
+ test("transform with lambda") {
+ val a = fn.col("a")
+ val transformed = a.transform(c => c * 2)
+ assert(transformed == (a * 2))
+ }
+
+ test("transform chaining") {
+ val a = fn.col("a")
+ val transformed = a.transform(c => c + 1).transform(c => c * 2)
+ assert(transformed == ((a + 1) * 2))
+ }
+
+ test("transform with complex lambda") {
+ val a = fn.col("a")
+ val transformed = a.transform(c => fn.when(c > 10, c * 2).otherwise(c))
+ val expected = fn.when(a > 10, a * 2).otherwise(a)
+ assert(transformed == expected)
+ }
+
+ test("transform with built-in functions") {
+ val a = fn.col("a")
+ val transformed = a.transform(fn.trim).transform(fn.upper)
+ val expected = fn.upper(fn.trim(a))
+ assert(transformed == expected)
+ }
+
+ test("transform with arithmetic operations") {
+ val a = fn.col("a")
+ val transformed = a.transform(_ + 10).transform(_ * 2).transform(_ - 5)
+ assert(transformed == (((a + 10) * 2) - 5))
+ }
+
+ test("transform mixing named functions and lambdas") {
+ val a = fn.col("a")
+ def triple(c: Column): Column = c * 3
+ val transformed = a.transform(triple).transform(c => c + 10)
+ assert(transformed == ((a * 3) + 10))
+ }
+
+ test("transform with nested transform") {
+ val a = fn.col("a")
+ val transformed = a.transform(_.transform(fn.upper))
+ val expected = fn.upper(a)
+ assert(transformed == expected)
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 75acd23..2e74981 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -3144,4 +3144,22 @@
checkAnswer(df.select($"dd" / ($"num" + 3)),
Seq((Duration.ofDays(2))).toDF())
}
+
+ test("Column.transform: built-in functions") {
+ val df = Seq(" hello ", " world ").toDF("text")
+
+ checkAnswer(
+ df.select($"text".transform(trim).transform(upper)),
+ Seq("HELLO", "WORLD").toDF()
+ )
+ }
+
+ test("Column.transform: lambda functions") {
+ val df = Seq(10, 20, 30).toDF("value")
+
+ checkAnswer(
+ df.select($"value".transform(_ + 5).transform(_ * 2).transform(_ - 10)),
+ Seq(20, 40, 60).toDF()
+ )
+ }
}