[SPARK-49422][CONNECT][SQL] Create a shared interface for KeyValueGroupedDataset

### What changes were proposed in this pull request?
This PR creates a shared interface for KeyValueGroupedDataset.

### Why are the changes needed?
We are creating a shared Scala Spark SQL interface for Classic and Connect.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47960 from hvanhovell/SPARK-49422.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index ce21f18..4450448 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -33,11 +33,11 @@
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.OrderUtils
 import org.apache.spark.sql.connect.client.SparkResult
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions.{struct, to_json}
-import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, UnresolvedAttribute, UnresolvedRegex}
+import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -534,7 +534,7 @@
    * @since 3.5.0
    */
   def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+    groupByKey(ToScalaUDF(func))(encoder)
 
   /** @inheritdoc */
   @scala.annotation.varargs
@@ -865,17 +865,17 @@
 
   /** @inheritdoc */
   def filter(f: FilterFunction[T]): Dataset[T] = {
-    filter(UdfUtils.filterFuncToScalaFunc(f))
+    filter(ToScalaUDF(f))
   }
 
   /** @inheritdoc */
   def map[U: Encoder](f: T => U): Dataset[U] = {
-    mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f))
+    mapPartitions(UDFAdaptors.mapToMapPartitions(f))
   }
 
   /** @inheritdoc */
   def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    map(UdfUtils.mapFunctionToScalaFunc(f))(encoder)
+    mapPartitions(UDFAdaptors.mapToMapPartitions(f))(encoder)
   }
 
   /** @inheritdoc */
@@ -893,24 +893,10 @@
   }
 
   /** @inheritdoc */
-  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder)
-  }
-
-  /** @inheritdoc */
-  override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] =
-    mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func))
-
-  /** @inheritdoc */
-  override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
-  }
-
-  /** @inheritdoc */
   @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
   def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = {
     val generator = SparkUserDefinedFunction(
-      UdfUtils.iterableOnceToSeq(f),
+      UDFAdaptors.iterableOnceToSeq(f),
       UnboundRowEncoder :: Nil,
       ScalaReflection.encoderFor[Seq[A]])
     select(col("*"), functions.inline(generator(struct(input: _*))))
@@ -921,31 +907,16 @@
   def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
       f: A => IterableOnce[B]): DataFrame = {
     val generator = SparkUserDefinedFunction(
-      UdfUtils.iterableOnceToSeq(f),
+      UDFAdaptors.iterableOnceToSeq(f),
       Nil,
       ScalaReflection.encoderFor[Seq[B]])
     select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn)))
   }
 
   /** @inheritdoc */
-  def foreach(f: T => Unit): Unit = {
-    foreachPartition(UdfUtils.foreachFuncToForeachPartitionsAdaptor(f))
-  }
-
-  /** @inheritdoc */
-  override def foreach(func: ForeachFunction[T]): Unit =
-    foreach(UdfUtils.foreachFuncToScalaFunc(func))
-
-  /** @inheritdoc */
   def foreachPartition(f: Iterator[T] => Unit): Unit = {
     // Delegate to mapPartition with empty result.
-    mapPartitions(UdfUtils.foreachPartitionFuncToMapPartitionsAdaptor(f))(RowEncoder(Seq.empty))
-      .collect()
-  }
-
-  /** @inheritdoc */
-  override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = {
-    foreachPartition(UdfUtils.foreachPartitionFuncToScalaFunc(func))
+    mapPartitions(UDFAdaptors.foreachPartitionToMapPartitions(f))(NullEncoder).collect()
   }
 
   /** @inheritdoc */
@@ -1465,6 +1436,22 @@
     super.dropDuplicatesWithinWatermark(col1, cols: _*)
 
   /** @inheritdoc */
+  override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    super.mapPartitions(f, encoder)
+
+  /** @inheritdoc */
+  override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] =
+    super.flatMap(func)
+
+  /** @inheritdoc */
+  override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    super.flatMap(f, encoder)
+
+  /** @inheritdoc */
+  override def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
+    super.foreachPartition(func)
+
+  /** @inheritdoc */
   @scala.annotation.varargs
   override def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] =
     super.repartition(numPartitions, partitionExprs: _*)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 04b620b..aef7efb 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -19,8 +19,8 @@
 
 import java.util.Arrays
 
+import scala.annotation.unused
 import scala.jdk.CollectionConverters._
-import scala.language.existentials
 
 import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
@@ -30,6 +30,7 @@
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr
+import org.apache.spark.sql.internal.UDFAdaptors
 import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode}
 
 /**
@@ -39,7 +40,11 @@
  *
  * @since 3.5.0
  */
-class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+class KeyValueGroupedDataset[K, V] private[sql] ()
+    extends api.KeyValueGroupedDataset[K, V, Dataset] {
+  type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]
+
+  private def unsupported(): Nothing = throw new UnsupportedOperationException()
 
   /**
    * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
@@ -48,499 +53,52 @@
    *
    * @since 3.5.0
    */
-  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
-    throw new UnsupportedOperationException
-  }
+  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = unsupported()
 
-  /**
-   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
-   * the data. The grouping key is unchanged by this.
-   *
-   * {{{
-   *   // Create values grouped by key from a Dataset[(K, V)]
-   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
-   * }}}
-   *
-   * @since 3.5.0
-   */
-  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
-    throw new UnsupportedOperationException
-  }
+  /** @inheritdoc */
+  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] =
+    unsupported()
 
-  /**
-   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
-   * the data. The grouping key is unchanged by this.
-   *
-   * {{{
-   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
-   *   Dataset<Tuple2<String, Integer>> ds = ...;
-   *   KeyValueGroupedDataset<String, Integer> grouped =
-   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
-   * }}}
-   *
-   * @since 3.5.0
-   */
-  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
-    mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
-  }
+  /** @inheritdoc */
+  def keys: Dataset[K] = unsupported()
 
-  /**
-   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over
-   * the Dataset to extract the keys and then running a distinct operation on those.
-   *
-   * @since 3.5.0
-   */
-  def keys: Dataset[K] = {
-    throw new UnsupportedOperationException
-  }
-
-  /**
-   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
-   * function will be passed the group key and an iterator that contains all of the elements in
-   * the group. The function can return an iterator containing elements of an arbitrary type which
-   * will be returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory. However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the
-   * memory constraints of their cluster.
-   *
-   * @since 3.5.0
-   */
-  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
-    flatMapSortedGroups()(f)
-  }
-
-  /**
-   * (Java-specific) Applies the given function to each group of data. For each unique group, the
-   * function will be passed the group key and an iterator that contains all of the elements in
-   * the group. The function can return an iterator containing elements of an arbitrary type which
-   * will be returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory. However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the
-   * memory constraints of their cluster.
-   *
-   * @since 3.5.0
-   */
-  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
-    flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
-  }
-
-  /**
-   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
-   * function will be passed the group key and a sorted iterator that contains all of the elements
-   * in the group. The function can return an iterator containing elements of an arbitrary type
-   * which will be returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory. However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the
-   * memory constraints of their cluster.
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
-   * sorted according to the given sort expressions. That sorting does not add computational
-   * complexity.
-   *
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
-      f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
-    throw new UnsupportedOperationException
-  }
+      f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] =
+    unsupported()
 
-  /**
-   * (Java-specific) Applies the given function to each group of data. For each unique group, the
-   * function will be passed the group key and a sorted iterator that contains all of the elements
-   * in the group. The function can return an iterator containing elements of an arbitrary type
-   * which will be returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory. However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the
-   * memory constraints of their cluster.
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
-   * sorted according to the given sort expressions. That sorting does not add computational
-   * complexity.
-   *
-   * @since 3.5.0
-   */
-  def flatMapSortedGroups[U](
-      SortExprs: Array[Column],
-      f: FlatMapGroupsFunction[K, V, U],
-      encoder: Encoder[U]): Dataset[U] = {
-    import org.apache.spark.util.ArrayImplicits._
-    flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)(
-      UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
-  }
+  /** @inheritdoc */
+  def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = unsupported()
 
-  /**
-   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
-   * function will be passed the group key and an iterator that contains all of the elements in
-   * the group. The function can return an element of arbitrary type which will be returned as a
-   * new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory. However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the
-   * memory constraints of their cluster.
-   *
-   * @since 3.5.0
-   */
-  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
-    flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
-  }
+  /** @inheritdoc */
+  protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = unsupported()
 
-  /**
-   * (Java-specific) Applies the given function to each group of data. For each unique group, the
-   * function will be passed the group key and an iterator that contains all of the elements in
-   * the group. The function can return an element of arbitrary type which will be returned as a
-   * new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory. However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the
-   * memory constraints of their cluster.
-   *
-   * @since 3.5.0
-   */
-  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
-    mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
-  }
-
-  /**
-   * (Scala-specific) Reduces the elements of each group of data using the specified binary
-   * function. The given function must be commutative and associative or the result may be
-   * non-deterministic.
-   *
-   * @since 3.5.0
-   */
-  def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
-    throw new UnsupportedOperationException
-  }
-
-  /**
-   * (Java-specific) Reduces the elements of each group of data using the specified binary
-   * function. The given function must be commutative and associative or the result may be
-   * non-deterministic.
-   *
-   * @since 3.5.0
-   */
-  def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
-    reduceGroups(UdfUtils.mapReduceFuncToScalaFunc(f))
-  }
-
-  /**
-   * Internal helper function for building typed aggregations that return tuples. For simplicity
-   * and code reuse, we do this without the help of the type system and then use helper functions
-   * that cast appropriately for the user facing interface.
-   */
-  protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    throw new UnsupportedOperationException
-  }
-
-  /**
-   * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the
-   * result of computing this aggregation over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
-    aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
-   * the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
-    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
-   * the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1, U2, U3](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
-    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
-   * the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1, U2, U3, U4](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
-    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
-   * the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1, U2, U3, U4, U5](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
-    aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
-   * the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1, U2, U3, U4, U5, U6](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5],
-      col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
-    aggUntyped(col1, col2, col3, col4, col5, col6)
-      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
-   * the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1, U2, U3, U4, U5, U6, U7](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5],
-      col6: TypedColumn[V, U6],
-      col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
-    aggUntyped(col1, col2, col3, col4, col5, col6, col7)
-      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
-   * the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.5.0
-   */
-  def agg[U1, U2, U3, U4, U5, U6, U7, U8](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5],
-      col6: TypedColumn[V, U6],
-      col7: TypedColumn[V, U7],
-      col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
-    aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8)
-      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]]
-
-  /**
-   * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for
-   * that key.
-   *
-   * @since 3.5.0
-   */
-  def count(): Dataset[(K, Long)] = agg(functions.count("*"))
-
-  /**
-   * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
-   * the function will be passed the grouping key and 2 iterators containing all elements in the
-   * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
-   * elements of an arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * @since 3.5.0
-   */
-  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
-      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
-    cogroupSorted(other)()()(f)
-  }
-
-  /**
-   * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
-   * function will be passed the grouping key and 2 iterators containing all elements in the group
-   * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
-   * of an arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * @since 3.5.0
-   */
-  def cogroup[U, R](
-      other: KeyValueGroupedDataset[K, U],
-      f: CoGroupFunction[K, V, U, R],
-      encoder: Encoder[R]): Dataset[R] = {
-    cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
-  }
-
-  /**
-   * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
-   * group, the function will be passed the grouping key and 2 sorted iterators containing all
-   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
-   * iterator containing elements of an arbitrary type which will be returned as a new
-   * [[Dataset]].
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
-   * sorted according to the given sort expressions. That sorting does not add computational
-   * complexity.
-   *
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
-      otherSortExprs: Column*)(
-      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
-    throw new UnsupportedOperationException
-  }
-
-  /**
-   * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
-   * group, the function will be passed the grouping key and 2 sorted iterators containing all
-   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
-   * iterator containing elements of an arbitrary type which will be returned as a new
-   * [[Dataset]].
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
-   * sorted according to the given sort expressions. That sorting does not add computational
-   * complexity.
-   *
-   * @since 3.5.0
-   */
-  def cogroupSorted[U, R](
-      other: KeyValueGroupedDataset[K, U],
-      thisSortExprs: Array[Column],
-      otherSortExprs: Array[Column],
-      f: CoGroupFunction[K, V, U, R],
-      encoder: Encoder[R]): Dataset[R] = {
-    import org.apache.spark.util.ArrayImplicits._
-    cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)(
-      otherSortExprs.toImmutableArraySeq: _*)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
-  }
+      otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] =
+    unsupported()
 
   protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
       outputMode: Option[OutputMode],
       timeoutConf: GroupStateTimeout,
       initialState: Option[KeyValueGroupedDataset[K, S]],
       isMapGroupWithState: Boolean)(
-      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
-    throw new UnsupportedOperationException
-  }
+      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = unsupported()
 
-  /**
-   * (Scala-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See
-   * [[org.apache.spark.sql.streaming.GroupState]] for more details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def mapGroupsWithState[S: Encoder, U: Encoder](
       func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
     mapGroupsWithState(GroupStateTimeout.NoTimeout)(func)
   }
 
-  /**
-   * (Scala-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See
-   * [[org.apache.spark.sql.streaming.GroupState]] for more details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param timeoutConf
-   *   Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(
       func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
     flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)(
-      UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+      UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func))
   }
 
-  /**
-   * (Scala-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See
-   * [[org.apache.spark.sql.streaming.GroupState]] for more details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param timeoutConf
-   *   Timeout Conf, see GroupStateTimeout for more details
-   * @param initialState
-   *   The user provided state that will be initialized when the first batch of data is processed
-   *   in the streaming query. The user defined function will be called on the state data even if
-   *   there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to
-   *   a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def mapGroupsWithState[S: Encoder, U: Encoder](
       timeoutConf: GroupStateTimeout,
       initialState: KeyValueGroupedDataset[K, S])(
@@ -549,134 +107,10 @@
       None,
       timeoutConf,
       Some(initialState),
-      isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+      isMapGroupWithState = true)(UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func))
   }
 
-  /**
-   * (Java-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See `GroupState` for more
-   * details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param stateEncoder
-   *   Encoder for the state type.
-   * @param outputEncoder
-   *   Encoder for the output type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
-  def mapGroupsWithState[S, U](
-      func: MapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U]): Dataset[U] = {
-    mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
-      stateEncoder,
-      outputEncoder)
-  }
-
-  /**
-   * (Java-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See `GroupState` for more
-   * details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param stateEncoder
-   *   Encoder for the state type.
-   * @param outputEncoder
-   *   Encoder for the output type.
-   * @param timeoutConf
-   *   Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
-  def mapGroupsWithState[S, U](
-      func: MapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U],
-      timeoutConf: GroupStateTimeout): Dataset[U] = {
-    mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
-      stateEncoder,
-      outputEncoder)
-  }
-
-  /**
-   * (Java-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See `GroupState` for more
-   * details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param stateEncoder
-   *   Encoder for the state type.
-   * @param outputEncoder
-   *   Encoder for the output type.
-   * @param timeoutConf
-   *   Timeout configuration for groups that do not receive data for a while.
-   * @param initialState
-   *   The user provided state that will be initialized when the first batch of data is processed
-   *   in the streaming query. The user defined function will be called on the state data even if
-   *   there are no other values in the group.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
-  def mapGroupsWithState[S, U](
-      func: MapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U],
-      timeoutConf: GroupStateTimeout,
-      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
-    mapGroupsWithState[S, U](timeoutConf, initialState)(
-      UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder)
-  }
-
-  /**
-   * (Scala-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See `GroupState` for more
-   * details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param outputMode
-   *   The output mode of the function.
-   * @param timeoutConf
-   *   Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def flatMapGroupsWithState[S: Encoder, U: Encoder](
       outputMode: OutputMode,
       timeoutConf: GroupStateTimeout)(
@@ -688,33 +122,7 @@
       isMapGroupWithState = false)(func)
   }
 
-  /**
-   * (Scala-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See `GroupState` for more
-   * details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param outputMode
-   *   The output mode of the function.
-   * @param timeoutConf
-   *   Timeout configuration for groups that do not receive data for a while.
-   * @param initialState
-   *   The user provided state that will be initialized when the first batch of data is processed
-   *   in the streaming query. The user defined function will be called on the state data even if
-   *   there are no other values in the group. To covert a Dataset `ds` of type of type
-   *   `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
-   *   {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what
-   *   types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def flatMapGroupsWithState[S: Encoder, U: Encoder](
       outputMode: OutputMode,
       timeoutConf: GroupStateTimeout,
@@ -727,201 +135,244 @@
       isMapGroupWithState = false)(func)
   }
 
-  /**
-   * (Java-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See `GroupState` for more
-   * details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param outputMode
-   *   The output mode of the function.
-   * @param stateEncoder
-   *   Encoder for the state type.
-   * @param outputEncoder
-   *   Encoder for the output type.
-   * @param timeoutConf
-   *   Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
-  def flatMapGroupsWithState[S, U](
+  /** @inheritdoc */
+  private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode): Dataset[U] =
+    unsupported()
+
+  /** @inheritdoc */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+    unsupported()
+
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      eventTimeColumnName: String,
+      outputMode: OutputMode): Dataset[U] = unsupported()
+
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      eventTimeColumnName: String,
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported()
+
+  // Overrides...
+  /** @inheritdoc */
+  override def mapValues[W](
+      func: MapFunction[V, W],
+      encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder)
+
+  /** @inheritdoc */
+  override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] =
+    super.flatMapGroups(f)
+
+  /** @inheritdoc */
+  override def flatMapGroups[U](
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder)
+
+  /** @inheritdoc */
+  override def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder)
+
+  /** @inheritdoc */
+  override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f)
+
+  /** @inheritdoc */
+  override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] =
+    super.mapGroups(f, encoder)
+
+  /** @inheritdoc */
+  override def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] =
+    super.mapGroupsWithState(func, stateEncoder, outputEncoder)
+
+  /** @inheritdoc */
+  override def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout): Dataset[U] =
+    super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf)
+
+  /** @inheritdoc */
+  override def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+    super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState)
+
+  /** @inheritdoc */
+  override def flatMapGroupsWithState[S, U](
       func: FlatMapGroupsWithStateFunction[K, V, S, U],
       outputMode: OutputMode,
       stateEncoder: Encoder[S],
       outputEncoder: Encoder[U],
-      timeoutConf: GroupStateTimeout): Dataset[U] = {
-    val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
-    flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
-  }
+      timeoutConf: GroupStateTimeout): Dataset[U] =
+    super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf)
 
-  /**
-   * (Java-specific) Applies the given function to each group of data, while maintaining a
-   * user-defined per-group state. The result Dataset will represent the objects returned by the
-   * function. For a static batch Dataset, the function will be invoked once per group. For a
-   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
-   * and updates to each group's state will be saved across invocations. See `GroupState` for more
-   * details.
-   *
-   * @tparam S
-   *   The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func
-   *   Function to be called on every group.
-   * @param outputMode
-   *   The output mode of the function.
-   * @param stateEncoder
-   *   Encoder for the state type.
-   * @param outputEncoder
-   *   Encoder for the output type.
-   * @param timeoutConf
-   *   Timeout configuration for groups that do not receive data for a while.
-   * @param initialState
-   *   The user provided state that will be initialized when the first batch of data is processed
-   *   in the streaming query. The user defined function will be called on the state data even if
-   *   there are no other values in the group. To covert a Dataset `ds` of type of type
-   *   `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
-   *   {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.5.0
-   */
-  def flatMapGroupsWithState[S, U](
+  /** @inheritdoc */
+  override def flatMapGroupsWithState[S, U](
       func: FlatMapGroupsWithStateFunction[K, V, S, U],
       outputMode: OutputMode,
       stateEncoder: Encoder[S],
       outputEncoder: Encoder[U],
       timeoutConf: GroupStateTimeout,
-      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
-    val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
-    flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)(
-      stateEncoder,
-      outputEncoder)
-  }
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState(
+    func,
+    outputMode,
+    stateEncoder,
+    outputEncoder,
+    timeoutConf,
+    initialState)
 
-  /**
-   * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
-   * API v2. We allow the user to act on per-group set of input rows along with keyed state and
-   * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will
-   * repeatedly invoke the interface methods for new rows in each trigger and the user's
-   * state/state variables will be stored persistently across invocations. Currently this operator
-   * is not supported with Spark Connect.
-   *
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor
-   *   Instance of statefulProcessor whose functions will be invoked by the operator.
-   * @param timeMode
-   *   The time mode semantics of the stateful processor for timers and TTL.
-   * @param outputMode
-   *   The output mode of the stateful processor.
-   */
-  private[sql] def transformWithState[U: Encoder](
-      statefulProcessor: StatefulProcessor[K, V, U],
-      timeMode: TimeMode,
-      outputMode: OutputMode): Dataset[U] = {
-    throw new UnsupportedOperationException
-  }
-
-  /**
-   * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
-   * v2. We allow the user to act on per-group set of input rows along with keyed state and the
-   * user can choose to output/return 0 or more rows. For a streaming dataframe, we will
-   * repeatedly invoke the interface methods for new rows in each trigger and the user's
-   * state/state variables will be stored persistently across invocations. Currently this operator
-   * is not supported with Spark Connect.
-   *
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor
-   *   Instance of statefulProcessor whose functions will be invoked by the operator.
-   * @param timeMode
-   *   The time mode semantics of the stateful processor for timers and TTL.
-   * @param outputMode
-   *   The output mode of the stateful processor.
-   * @param outputEncoder
-   *   Encoder for the output type.
-   */
-  private[sql] def transformWithState[U: Encoder](
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       timeMode: TimeMode,
       outputMode: OutputMode,
-      outputEncoder: Encoder[U]): Dataset[U] = {
-    throw new UnsupportedOperationException
-  }
+      outputEncoder: Encoder[U]) =
+    super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder)
 
-  /**
-   * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
-   * API v2. Functions as the function above, but with additional initial state. Currently this
-   * operator is not supported with Spark Connect.
-   *
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @tparam S
-   *   The type of initial state objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor
-   *   Instance of statefulProcessor whose functions will be invoked by the operator.
-   * @param timeMode
-   *   The time mode semantics of the stateful processor for timers and TTL.
-   * @param outputMode
-   *   The output mode of the stateful processor.
-   * @param initialState
-   *   User provided initial state that will be used to initiate state for the query in the first
-   *   batch.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
-  private[sql] def transformWithState[U: Encoder, S: Encoder](
-      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
-      timeMode: TimeMode,
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      eventTimeColumnName: String,
       outputMode: OutputMode,
-      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
-    throw new UnsupportedOperationException
-  }
+      outputEncoder: Encoder[U]) =
+    super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder)
 
-  /**
-   * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
-   * v2. Functions as the function above, but with additional initial state. Currently this
-   * operator is not supported with Spark Connect.
-   *
-   * @tparam U
-   *   The type of the output objects. Must be encodable to Spark SQL types.
-   * @tparam S
-   *   The type of initial state objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor
-   *   Instance of statefulProcessor whose functions will be invoked by the operator.
-   * @param timeMode
-   *   The time mode semantics of the stateful processor for timers and TTL.
-   * @param outputMode
-   *   The output mode of the stateful processor.
-   * @param initialState
-   *   User provided initial state that will be used to initiate state for the query in the first
-   *   batch.
-   * @param outputEncoder
-   *   Encoder for the output type.
-   * @param initialStateEncoder
-   *   Encoder for the initial state type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
-  private[sql] def transformWithState[U: Encoder, S: Encoder](
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       timeMode: TimeMode,
       outputMode: OutputMode,
       initialState: KeyValueGroupedDataset[K, S],
       outputEncoder: Encoder[U],
-      initialStateEncoder: Encoder[S]): Dataset[U] = {
-    throw new UnsupportedOperationException
-  }
+      initialStateEncoder: Encoder[S]) = super.transformWithState(
+    statefulProcessor,
+    timeMode,
+    outputMode,
+    initialState,
+    outputEncoder,
+    initialStateEncoder)
+
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S],
+      eventTimeColumnName: String,
+      outputEncoder: Encoder[U],
+      initialStateEncoder: Encoder[S]) = super.transformWithState(
+    statefulProcessor,
+    outputMode,
+    initialState,
+    eventTimeColumnName,
+    outputEncoder,
+    initialStateEncoder)
+
+  /** @inheritdoc */
+  override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f)
+
+  /** @inheritdoc */
+  override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1)
+
+  /** @inheritdoc */
+  override def agg[U1, U2](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
+    super.agg(col1, col2, col3, col4, col5)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5, U6](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
+    super.agg(col1, col2, col3, col4, col5, col6)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5, U6, U7](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
+    super.agg(col1, col2, col3, col4, col5, col6, col7)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5, U6, U7, U8](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7],
+      col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
+    super.agg(col1, col2, col3, col4, col5, col6, col7, col8)
+
+  /** @inheritdoc */
+  override def count(): Dataset[(K, Long)] = super.count()
+
+  /** @inheritdoc */
+  override def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] =
+    super.cogroup(other)(f)
+
+  /** @inheritdoc */
+  override def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = super.cogroup(other, f, encoder)
+
+  /** @inheritdoc */
+  override def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] =
+    super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder)
 }
 
 /**
@@ -934,12 +385,11 @@
 private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     private val sparkSession: SparkSession,
     private val plan: proto.Plan,
-    private val ikEncoder: AgnosticEncoder[IK],
     private val kEncoder: AgnosticEncoder[K],
     private val ivEncoder: AgnosticEncoder[IV],
     private val vEncoder: AgnosticEncoder[V],
     private val groupingExprs: java.util.List[proto.Expression],
-    private val valueMapFunc: IV => V,
+    private val valueMapFunc: Option[IV => V],
     private val keysFunc: () => Dataset[IK])
     extends KeyValueGroupedDataset[K, V] {
   import sparkSession.RichColumn
@@ -948,7 +398,6 @@
     new KeyValueGroupedDatasetImpl[L, V, IK, IV](
       sparkSession,
       plan,
-      ikEncoder,
       encoderFor[L],
       ivEncoder,
       vEncoder,
@@ -961,12 +410,13 @@
     new KeyValueGroupedDatasetImpl[K, W, IK, IV](
       sparkSession,
       plan,
-      ikEncoder,
       kEncoder,
       ivEncoder,
       encoderFor[W],
       groupingExprs,
-      valueMapFunc.andThen(valueFunc),
+      valueMapFunc
+        .map(_.andThen(valueFunc))
+        .orElse(Option(valueFunc.asInstanceOf[IV => W])),
       keysFunc)
   }
 
@@ -979,8 +429,7 @@
   override def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
       f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
     // Apply mapValues changes to the udf
-    val nf =
-      if (valueMapFunc == UdfUtils.identical()) f else UdfUtils.mapValuesAdaptor(f, valueMapFunc)
+    val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc)
     val outputEncoder = encoderFor[U]
     sparkSession.newDataset[U](outputEncoder) { builder =>
       builder.getGroupMapBuilder
@@ -994,10 +443,9 @@
   override def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
       thisSortExprs: Column*)(otherSortExprs: Column*)(
       f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
-    assert(other.isInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]])
-    val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]]
+    val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]]
     // Apply mapValues changes to the udf
-    val nf = UdfUtils.mapValuesAdaptor(f, valueMapFunc, otherImpl.valueMapFunc)
+    val nf = UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc)
     val outputEncoder = encoderFor[R]
     sparkSession.newDataset[R](outputEncoder) { builder =>
       builder.getCoGroupMapBuilder
@@ -1012,8 +460,7 @@
   }
 
   override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    // TODO(SPARK-43415): For each column, apply the valueMap func first
-    // apply keyAs change
+    // TODO(SPARK-43415): For each column, apply the valueMap func first...
     val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => encoderFor(c.encoder)))
     sparkSession.newDataset(rEnc) { builder =>
       builder.getAggregateBuilder
@@ -1047,22 +494,15 @@
       throw new IllegalArgumentException("The output mode of function should be append or update")
     }
 
-    if (initialState.isDefined) {
-      assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
-    }
-
     val initialStateImpl = if (initialState.isDefined) {
+      assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
       initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
     } else {
       null
     }
 
     val outputEncoder = encoderFor[U]
-    val nf = if (valueMapFunc == UdfUtils.identical()) {
-      func
-    } else {
-      UdfUtils.mapValuesAdaptor(func, valueMapFunc)
-    }
+    val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc)
 
     sparkSession.newDataset[U](outputEncoder) { builder =>
       val groupMapBuilder = builder.getGroupMapBuilder
@@ -1097,6 +537,7 @@
    * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the
    * server side. We null out the instance for now.
    */
+  @unused("this is used by java serialization")
   private def writeReplace(): Any = null
 }
 
@@ -1114,11 +555,10 @@
       session,
       ds.plan,
       kEncoder,
-      kEncoder,
       ds.agnosticEncoder,
       ds.agnosticEncoder,
       Arrays.asList(toExpr(gf.apply(col("*")))),
-      UdfUtils.identical(),
+      None,
       () => ds.map(groupingFunc)(kEncoder))
   }
 
@@ -1137,11 +577,10 @@
       session,
       df.plan,
       kEncoder,
-      kEncoder,
       vEncoder,
       vEncoder,
       (Seq(dummyGroupingFunc) ++ groupingExprs).map(toExpr).asJava,
-      UdfUtils.identical(),
+      None,
       () => df.select(groupingExprs: _*).as(kEncoder))
   }
 }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index c9b011c..ea13635 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -80,12 +80,7 @@
     colNames.map(df.col)
   }
 
-  /**
-   * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of
-   * current `RelationalGroupedDataset`.
-   *
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
     KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs)
   }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 49f77a1..c982609 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -24,6 +24,7 @@
 import org.apache.spark.annotation.{DeveloperApi, Stable}
 import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction}
 import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, Encoder, Observation, Row, TypedColumn}
+import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors}
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
@@ -251,7 +252,6 @@
    * @since 1.6.0
    */
   def explain(extended: Boolean): Unit = if (extended) {
-    // TODO move ExplainMode?
     explain("extended")
   } else {
     explain("simple")
@@ -1384,7 +1384,7 @@
    * @group action
    * @since 1.6.0
    */
-  def reduce(func: ReduceFunction[T]): T = reduce(func.call)
+  def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func))
 
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
@@ -2437,7 +2437,8 @@
    * @group typedrel
    * @since 1.6.0
    */
-  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): DS[U]
+  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): DS[U] =
+    mapPartitions(ToScalaUDF(f))(encoder)
 
   /**
    * (Scala-specific)
@@ -2448,7 +2449,7 @@
    * @since 1.6.0
    */
   def flatMap[U: Encoder](func: T => IterableOnce[U]): DS[U] =
-    mapPartitions(_.flatMap(func))
+    mapPartitions(UDFAdaptors.flatMapToMapPartitions[T, U](func))
 
   /**
    * (Java-specific)
@@ -2459,8 +2460,7 @@
    * @since 1.6.0
    */
   def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): DS[U] = {
-    val func: T => Iterator[U] = x => f.call(x).asScala
-    flatMap(func)(encoder)
+    mapPartitions(UDFAdaptors.flatMapToMapPartitions(f))(encoder)
   }
 
   /**
@@ -2469,7 +2469,9 @@
    * @group action
    * @since 1.6.0
    */
-  def foreach(f: T => Unit): Unit
+  def foreach(f: T => Unit): Unit = {
+    foreachPartition(UDFAdaptors.foreachToForeachPartition(f))
+  }
 
   /**
    * (Java-specific)
@@ -2478,7 +2480,9 @@
    * @group action
    * @since 1.6.0
    */
-  def foreach(func: ForeachFunction[T]): Unit = foreach(func.call)
+  def foreach(func: ForeachFunction[T]): Unit = {
+    foreachPartition(UDFAdaptors.foreachToForeachPartition(func))
+  }
 
   /**
    * Applies a function `f` to each partition of this Dataset.
@@ -2496,7 +2500,7 @@
    * @since 1.6.0
    */
   def foreachPartition(func: ForeachPartitionFunction[T]): Unit = {
-    foreachPartition((it: Iterator[T]) => func.call(it.asJava))
+    foreachPartition(ToScalaUDF(func))
   }
 
   /**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala
new file mode 100644
index 0000000..5e73da2
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala
@@ -0,0 +1,955 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.api
+
+import org.apache.spark.api.java.function.{CoGroupFunction, FlatMapGroupsFunction, FlatMapGroupsWithStateFunction, MapFunction, MapGroupsFunction, MapGroupsWithStateFunction, ReduceFunction}
+import org.apache.spark.sql.{Column, Encoder, TypedColumn}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder
+import org.apache.spark.sql.functions.{count => cnt, lit}
+import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors}
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode}
+
+/**
+ * A [[Dataset]] has been logically grouped by a user specified grouping key.  Users should not
+ * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on
+ * an existing [[Dataset]].
+ *
+ * @since 2.0.0
+ */
+abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Serializable {
+  type KVDS[KY, VL] <: KeyValueGroupedDataset[KY, VL, DS]
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+   * specified type. The mapping of key columns to the type follows the same rules as `as` on
+   * [[Dataset]].
+   *
+   * @since 1.6.0
+   */
+  def keyAs[L: Encoder]: KVDS[L, V]
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W: Encoder](func: V => W): KVDS[K, W]
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W](
+      func: MapFunction[V, W],
+      encoder: Encoder[W]): KVDS[K, W] = {
+    mapValues(ToScalaUDF(func))(encoder)
+  }
+
+  /**
+   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping
+   * over the Dataset to extract the keys and then running a distinct operation on those.
+   *
+   * @since 1.6.0
+   */
+  def keys: DS[K]
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an iterator containing elements of an arbitrary type which will be returned
+   * as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): DS[U] = {
+    flatMapSortedGroups(Nil: _*)(f)
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an iterator containing elements of an arbitrary type which will be returned
+   * as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): DS[U] = {
+    flatMapGroups(ToScalaUDF(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and a sorted iterator that contains all of the elements in the group.
+   * The function can return an iterator containing elements of an arbitrary type which will be
+   * returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
+   * to be sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.api.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.4.0
+   */
+  def flatMapSortedGroups[U: Encoder](
+      sortExprs: Column*)(
+      f: (K, Iterator[V]) => IterableOnce[U]): DS[U]
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and a sorted iterator that contains all of the elements in the group.
+   * The function can return an iterator containing elements of an arbitrary type which will be
+   * returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
+   * to be sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.api.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.4.0
+   */
+  def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): DS[U] = {
+    import org.apache.spark.util.ArrayImplicits._
+    flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)(ToScalaUDF(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): DS[U] = {
+    flatMapGroups(UDFAdaptors.mapGroupsToFlatMapGroups(f))
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): DS[U] = {
+    mapGroups(ToScalaUDF(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.2.0
+   */
+  def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], GroupState[S]) => U): DS[U]
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func        Function to be called on every group.
+   * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.2.0
+   */
+  def mapGroupsWithState[S: Encoder, U: Encoder](
+      timeoutConf: GroupStateTimeout)(
+      func: (K, Iterator[V], GroupState[S]) => U): DS[U]
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param timeoutConf Timeout Conf, see GroupStateTimeout for more details
+   * @param initialState The user provided state that will be initialized when the first batch
+   *                     of data is processed in the streaming query. The user defined function
+   *                     will be called on the state data even if there are no other values in
+   *                     the group. To convert a Dataset ds of type Dataset[(K, S)] to a
+   *                     KeyValueGroupedDataset[K, S]
+   *                     do {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.2.0
+   */
+  def mapGroupsWithState[S: Encoder, U: Encoder](
+      timeoutConf: GroupStateTimeout,
+      initialState: KVDS[K, S])(
+      func: (K, Iterator[V], GroupState[S]) => U): DS[U]
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See `GroupState` for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.2.0
+   */
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): DS[U] = {
+    mapGroupsWithState[S, U](ToScalaUDF(func))(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See `GroupState` for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.2.0
+   */
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout): DS[U] = {
+    mapGroupsWithState[S, U](timeoutConf)(ToScalaUDF(func))(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See `GroupState` for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
+   * @param initialState  The user provided state that will be initialized when the first batch
+   *                      of data is processed in the streaming query. The user defined function
+   *                      will be called on the state data even if there are no other values in
+   *                      the group.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.2.0
+   */
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout,
+      initialState: KVDS[K, S]): DS[U] = {
+    val f = ToScalaUDF(func)
+    mapGroupsWithState[S, U](timeoutConf, initialState)(f)(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See `GroupState` for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func        Function to be called on every group.
+   * @param outputMode  The output mode of the function.
+   * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.2.0
+   */
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      outputMode: OutputMode,
+      timeoutConf: GroupStateTimeout)(
+      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): DS[U]
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See `GroupState` for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func Function to be called on every group.
+   * @param outputMode The output mode of the function.
+   * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
+   * @param initialState The user provided state that will be initialized when the first batch
+   *                     of data is processed in the streaming query. The user defined function
+   *                     will be called on the state data even if there are no other values in
+   *                     the group. To covert a Dataset `ds` of type  of type `Dataset[(K, S)]`
+   *                     to a `KeyValueGroupedDataset[K, S]`, use
+   *                     {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.2.0
+   */
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      outputMode: OutputMode,
+      timeoutConf: GroupStateTimeout,
+      initialState: KVDS[K, S])(
+      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): DS[U]
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See `GroupState` for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param outputMode    The output mode of the function.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.2.0
+   */
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout): DS[U] = {
+    val f = ToScalaUDF(func)
+    flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See `GroupState` for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param outputMode    The output mode of the function.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
+   * @param initialState  The user provided state that will be initialized when the first batch
+   *                      of data is processed in the streaming query. The user defined function
+   *                      will be called on the state data even if there are no other values in
+   *                      the group. To covert a Dataset `ds` of type  of type `Dataset[(K, S)]`
+   *                      to a `KeyValueGroupedDataset[K, S]`, use
+   * {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.2.0
+   */
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout,
+      initialState: KVDS[K, S]): DS[U] = {
+    flatMapGroupsWithState[S, U](
+      outputMode,
+      timeoutConf,
+      initialState)(
+      ToScalaUDF(func))(
+      stateEncoder,
+      outputEncoder)
+  }
+
+
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * We allow the user to act on per-group set of input rows along with keyed state and the
+   * user can choose to output/return 0 or more rows.
+   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
+   * in each trigger and the user's state/state variables will be stored persistently across
+   * invocations.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked
+   *                          by the operator.
+   * @param timeMode          The time mode semantics of the stateful processor for timers and TTL.
+   * @param outputMode        The output mode of the stateful processor.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode): DS[U]
+
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * We allow the user to act on per-group set of input rows along with keyed state and the
+   * user can choose to output/return 0 or more rows.
+   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
+   * in each trigger and the user's state/state variables will be stored persistently across
+   * invocations.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor   Instance of statefulProcessor whose functions will
+   *                            be invoked by the operator.
+   * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
+   *                            transformWithState will use the new eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted output adheres to
+   *                            the watermark boundary, otherwise streaming query will fail.
+   * @param outputMode          The output mode of the stateful processor.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      eventTimeColumnName: String,
+      outputMode: OutputMode): DS[U]
+
+  /**
+   * (Java-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * We allow the user to act on per-group set of input rows along with keyed state and the
+   * user can choose to output/return 0 or more rows.
+   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
+   * in each trigger and the user's state/state variables will be stored persistently across
+   * invocations.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
+   *                          operator.
+   * @param timeMode The time mode semantics of the stateful processor for timers and TTL.
+   * @param outputMode The output mode of the stateful processor.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      outputEncoder: Encoder[U]): DS[U] = {
+    transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder)
+  }
+
+  /**
+   * (Java-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * We allow the user to act on per-group set of input rows along with keyed state and the
+   * user can choose to output/return 0 or more rows.
+   *
+   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
+   * in each trigger and the user's state/state variables will be stored persistently across
+   * invocations.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
+   *                          operator.
+   * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
+   *                            transformWithState will use the new eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted output adheres to
+   *                            the watermark boundary, otherwise streaming query will fail.
+   * @param outputMode        The output mode of the stateful processor.
+   * @param outputEncoder     Encoder for the output type.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      eventTimeColumnName: String,
+      outputMode: OutputMode,
+      outputEncoder: Encoder[U]): DS[U] = {
+    transformWithState(statefulProcessor, eventTimeColumnName, outputMode)(outputEncoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * Functions as the function above, but with additional initial state.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor Instance of statefulProcessor whose functions will
+   *                          be invoked by the operator.
+   * @param timeMode          The time mode semantics of the stateful processor for timers and TTL.
+   * @param outputMode        The output mode of the stateful processor.
+   * @param initialState      User provided initial state that will be used to initiate state for
+   *                          the query in the first batch.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: KVDS[K, S]): DS[U]
+
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * Functions as the function above, but with additional eventTimeColumnName for output.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
+   *
+   * @param statefulProcessor   Instance of statefulProcessor whose functions will
+   *                            be invoked by the operator.
+   * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
+   *                            transformWithState will use the new eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted output adheres to
+   *                            the watermark boundary, otherwise streaming query will fail.
+   * @param outputMode          The output mode of the stateful processor.
+   * @param initialState        User provided initial state that will be used to initiate state for
+   *                            the query in the first batch.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      eventTimeColumnName: String,
+      outputMode: OutputMode,
+      initialState: KVDS[K, S]): DS[U]
+
+  /**
+   * (Java-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * Functions as the function above, but with additional initialStateEncoder for state encoding.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor   Instance of statefulProcessor whose functions will
+   *                            be invoked by the operator.
+   * @param timeMode            The time mode semantics of the stateful processor for
+   *                            timers and TTL.
+   * @param outputMode          The output mode of the stateful processor.
+   * @param initialState        User provided initial state that will be used to initiate state for
+   *                            the query in the first batch.
+   * @param outputEncoder       Encoder for the output type.
+   * @param initialStateEncoder Encoder for the initial state type.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: KVDS[K, S],
+      outputEncoder: Encoder[U],
+      initialStateEncoder: Encoder[S]): DS[U] = {
+    transformWithState(statefulProcessor, timeMode,
+      outputMode, initialState)(outputEncoder, initialStateEncoder)
+  }
+
+  /**
+   * (Java-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * Functions as the function above, but with additional eventTimeColumnName for output.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor Instance of statefulProcessor whose functions will
+   *                          be invoked by the operator.
+   * @param outputMode        The output mode of the stateful processor.
+   * @param initialState      User provided initial state that will be used to initiate state for
+   *                          the query in the first batch.
+   * @param eventTimeColumnName event column in the output dataset. Any operations after
+   *                            transformWithState will use the new eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted output adheres to
+   *                            the watermark boundary, otherwise streaming query will fail.
+   * @param outputEncoder     Encoder for the output type.
+   * @param initialStateEncoder Encoder for the initial state type.
+   *
+   * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      outputMode: OutputMode,
+      initialState: KVDS[K, S],
+      eventTimeColumnName: String,
+      outputEncoder: Encoder[U],
+      initialStateEncoder: Encoder[S]): DS[U] = {
+    transformWithState(statefulProcessor, eventTimeColumnName,
+      outputMode, initialState)(outputEncoder, initialStateEncoder)
+  }
+  /**
+   * (Scala-specific)
+   * Reduces the elements of each group of data using the specified binary function.
+   * The given function must be commutative and associative or the result may be non-deterministic.
+   *
+   * @since 1.6.0
+   */
+  def reduceGroups(f: (V, V) => V): DS[(K, V)]
+
+  /**
+   * (Java-specific)
+   * Reduces the elements of each group of data using the specified binary function.
+   * The given function must be commutative and associative or the result may be non-deterministic.
+   *
+   * @since 1.6.0
+   */
+  def reduceGroups(f: ReduceFunction[V]): DS[(K, V)] = {
+    reduceGroups(ToScalaUDF(f))
+  }
+
+  /**
+   * Internal helper function for building typed aggregations that return tuples.  For simplicity
+   * and code reuse, we do this without the help of the type system and then use helper functions
+   * that cast appropriately for the user facing interface.
+   */
+  protected def aggUntyped(columns: TypedColumn[_, _]*): DS[_]
+
+  /**
+   * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing this aggregation over all elements in the group.
+   *
+   * @since 1.6.0
+   */
+  def agg[U1](col1: TypedColumn[V, U1]): DS[(K, U1)] =
+    aggUntyped(col1).asInstanceOf[DS[(K, U1)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
+   */
+  def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): DS[(K, U1, U2)] =
+    aggUntyped(col1, col2).asInstanceOf[DS[(K, U1, U2)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
+   */
+  def agg[U1, U2, U3](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3]): DS[(K, U1, U2, U3)] =
+    aggUntyped(col1, col2, col3).asInstanceOf[DS[(K, U1, U2, U3)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
+   */
+  def agg[U1, U2, U3, U4](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4]): DS[(K, U1, U2, U3, U4)] =
+    aggUntyped(col1, col2, col3, col4).asInstanceOf[DS[(K, U1, U2, U3, U4)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.0.0
+   */
+  def agg[U1, U2, U3, U4, U5](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5]): DS[(K, U1, U2, U3, U4, U5)] =
+    aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[DS[(K, U1, U2, U3, U4, U5)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.0.0
+   */
+  def agg[U1, U2, U3, U4, U5, U6](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6]): DS[(K, U1, U2, U3, U4, U5, U6)] =
+    aggUntyped(col1, col2, col3, col4, col5, col6)
+      .asInstanceOf[DS[(K, U1, U2, U3, U4, U5, U6)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.0.0
+   */
+  def agg[U1, U2, U3, U4, U5, U6, U7](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7]): DS[(K, U1, U2, U3, U4, U5, U6, U7)] =
+    aggUntyped(col1, col2, col3, col4, col5, col6, col7)
+      .asInstanceOf[DS[(K, U1, U2, U3, U4, U5, U6, U7)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.0.0
+   */
+  def agg[U1, U2, U3, U4, U5, U6, U7, U8](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7],
+      col8: TypedColumn[V, U8]): DS[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
+    aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8)
+      .asInstanceOf[DS[(K, U1, U2, U3, U4, U5, U6, U7, U8)]]
+
+  /**
+   * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
+   * for that key.
+   *
+   * @since 1.6.0
+   */
+  def count(): DS[(K, Long)] = agg(cnt(lit(1)).as(PrimitiveLongEncoder))
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each cogrouped data.  For each unique group, the function will
+   * be passed the grouping key and 2 iterators containing all elements in the group from
+   * [[Dataset]] `this` and `other`.  The function can return an iterator containing elements of an
+   * arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 1.6.0
+   */
+  def cogroup[U, R: Encoder](
+      other: KVDS[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): DS[R] = {
+    cogroupSorted(other)(Nil: _*)(Nil: _*)(f)
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each cogrouped data.  For each unique group, the function will
+   * be passed the grouping key and 2 iterators containing all elements in the group from
+   * [[Dataset]] `this` and `other`.  The function can return an iterator containing elements of an
+   * arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 1.6.0
+   */
+  def cogroup[U, R](
+      other: KVDS[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): DS[R] = {
+    cogroup(other)(ToScalaUDF(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Applies the given function to each sorted cogrouped data.  For each unique group, the function
+   * will be passed the grouping key and 2 sorted iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`.  The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators
+   * to be sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.api.KeyValueGroupedDataset#cogroup]]
+   * @since 3.4.0
+   */
+  def cogroupSorted[U, R : Encoder](
+      other: KVDS[K, U])(
+      thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): DS[R]
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each sorted cogrouped data.  For each unique group, the function
+   * will be passed the grouping key and 2 sorted iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`.  The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators
+   * to be sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.api.KeyValueGroupedDataset#cogroup]]
+   * @since 3.4.0
+   */
+  def cogroupSorted[U, R](
+      other: KVDS[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): DS[R] = {
+    import org.apache.spark.util.ArrayImplicits._
+    cogroupSorted(other)(
+      thisSortExprs.toImmutableArraySeq: _*)(otherSortExprs.toImmutableArraySeq: _*)(
+      ToScalaUDF(f))(encoder)
+  }
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala
index 30b2992..35d6d13 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala
@@ -21,7 +21,7 @@
 import _root_.java.util
 
 import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.{functions, Column, Row}
+import org.apache.spark.sql.{functions, Column, Encoder, Row}
 
 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -66,6 +66,14 @@
   }
 
   /**
+   * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions
+   * of current `RelationalGroupedDataset`.
+   *
+   * @since 3.0.0
+   */
+  def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T, DS]
+
+  /**
    * (Scala-specific) Compute aggregates by specifying the column names and
    * aggregate methods. The resulting `DataFrame` will also contain the grouping columns.
    *
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala
index 25ea37f..66ea50c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala
@@ -16,14 +16,55 @@
  */
 package org.apache.spark.sql.internal
 
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.api.java.function.{CoGroupFunction, FilterFunction, FlatMapFunction, FlatMapGroupsFunction, FlatMapGroupsWithStateFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapGroupsFunction, MapGroupsWithStateFunction, MapPartitionsFunction, ReduceFunction}
 import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.streaming.GroupState
 
 /**
- * Helper class that provides conversions from org.apache.spark.sql.api.java.Function* to
- * scala.Function*.
+ * Helper class that provides conversions from org.apache.spark.sql.api.java.Function* and
+ * org.apache.spark.api.java.function.* to scala functions.
+ *
+ * Please note that this class is being used in Spark Connect Scala UDFs. We need to be careful
+ * with any modifications to this class, otherwise we will break backwards compatibility. Concretely
+ * this means you can only add methods to this class. You cannot rename the class, move it, change
+ * its `serialVersionUID`, remove methods, change method signatures, or change method semantics.
  */
 @SerialVersionUID(2019907615267866045L)
 private[sql] object ToScalaUDF extends Serializable {
+  def apply[T](f: FilterFunction[T]): T => Boolean = f.call
+
+  def apply[T](f: ReduceFunction[T]): (T, T) => T = f.call
+
+  def apply[V, W](f: MapFunction[V, W]): V => W = f.call
+
+  def apply[K, V, U](f: MapGroupsFunction[K, V, U]): (K, Iterator[V]) => U =
+    (key, values) => f.call(key, values.asJava)
+
+  def apply[K, V, S, U](f: MapGroupsWithStateFunction[K, V, S, U])
+    : (K, Iterator[V], GroupState[S]) => U =
+    (key, values, state) => f.call(key, values.asJava, state)
+
+  def apply[V, U](f: MapPartitionsFunction[V, U]): Iterator[V] => Iterator[U] =
+    values => f.call(values.asJava).asScala
+
+  def apply[K, V, U](f: FlatMapGroupsFunction[K, V, U]): (K, Iterator[V]) => Iterator[U] =
+    (key, values) => f.call(key, values.asJava).asScala
+
+  def apply[K, V, S, U](f: FlatMapGroupsWithStateFunction[K, V, S, U])
+    : (K, Iterator[V], GroupState[S]) => Iterator[U] =
+    (key, values, state) => f.call(key, values.asJava, state).asScala
+
+  def apply[K, V, U, R](f: CoGroupFunction[K, V, U, R])
+    : (K, Iterator[V], Iterator[U]) => Iterator[R] =
+    (key, left, right) => f.call(key, left.asJava, right.asJava).asScala
+
+  def apply[V](f: ForeachFunction[V]): V => Unit = f.call
+
+  def apply[V](f: ForeachPartitionFunction[V]): Iterator[V] => Unit =
+    values => f.call(values.asJava)
+
   // scalastyle:off line.size.limit
 
   /* register 0-22 were generated by this script
@@ -38,7 +79,7 @@
         |/**
         | * Create a scala.Function$i wrapper for a org.apache.spark.sql.api.java.UDF$i instance.
         | */
-        |def apply(f: UDF$i[$extTypeArgs]): AnyRef = {
+        |def apply(f: UDF$i[$extTypeArgs]): Function$i[$anyTypeArgs] = {
         |  $funcCall
         |}""".stripMargin)
     }
@@ -47,162 +88,239 @@
   /**
    * Create a scala.Function0 wrapper for a org.apache.spark.sql.api.java.UDF0 instance.
    */
-  def apply(f: UDF0[_]): AnyRef = {
+  def apply(f: UDF0[_]): () => Any = {
     () => f.asInstanceOf[UDF0[Any]].call()
   }
 
   /**
    * Create a scala.Function1 wrapper for a org.apache.spark.sql.api.java.UDF1 instance.
    */
-  def apply(f: UDF1[_, _]): AnyRef = {
+  def apply(f: UDF1[_, _]): (Any) => Any = {
     f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
   }
 
   /**
    * Create a scala.Function2 wrapper for a org.apache.spark.sql.api.java.UDF2 instance.
    */
-  def apply(f: UDF2[_, _, _]): AnyRef = {
+  def apply(f: UDF2[_, _, _]): (Any, Any) => Any = {
     f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
   }
 
   /**
    * Create a scala.Function3 wrapper for a org.apache.spark.sql.api.java.UDF3 instance.
    */
-  def apply(f: UDF3[_, _, _, _]): AnyRef = {
+  def apply(f: UDF3[_, _, _, _]): (Any, Any, Any) => Any = {
     f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function4 wrapper for a org.apache.spark.sql.api.java.UDF4 instance.
    */
-  def apply(f: UDF4[_, _, _, _, _]): AnyRef = {
+  def apply(f: UDF4[_, _, _, _, _]): (Any, Any, Any, Any) => Any = {
     f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function5 wrapper for a org.apache.spark.sql.api.java.UDF5 instance.
    */
-  def apply(f: UDF5[_, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF5[_, _, _, _, _, _]): (Any, Any, Any, Any, Any) => Any = {
     f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function6 wrapper for a org.apache.spark.sql.api.java.UDF6 instance.
    */
-  def apply(f: UDF6[_, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF6[_, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any) => Any = {
     f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function7 wrapper for a org.apache.spark.sql.api.java.UDF7 instance.
    */
-  def apply(f: UDF7[_, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF7[_, _, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any, Any) => Any = {
     f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function8 wrapper for a org.apache.spark.sql.api.java.UDF8 instance.
    */
-  def apply(f: UDF8[_, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF8[_, _, _, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any, Any, Any) => Any = {
     f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function9 wrapper for a org.apache.spark.sql.api.java.UDF9 instance.
    */
-  def apply(f: UDF9[_, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF9[_, _, _, _, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any = {
     f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function10 wrapper for a org.apache.spark.sql.api.java.UDF10 instance.
    */
-  def apply(f: UDF10[_, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF10[_, _, _, _, _, _, _, _, _, _, _]): Function10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function11 wrapper for a org.apache.spark.sql.api.java.UDF11 instance.
    */
-  def apply(f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _]): Function11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function12 wrapper for a org.apache.spark.sql.api.java.UDF12 instance.
    */
-  def apply(f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]): Function12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function13 wrapper for a org.apache.spark.sql.api.java.UDF13 instance.
    */
-  def apply(f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function14 wrapper for a org.apache.spark.sql.api.java.UDF14 instance.
    */
-  def apply(f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function15 wrapper for a org.apache.spark.sql.api.java.UDF15 instance.
    */
-  def apply(f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function16 wrapper for a org.apache.spark.sql.api.java.UDF16 instance.
    */
-  def apply(f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function17 wrapper for a org.apache.spark.sql.api.java.UDF17 instance.
    */
-  def apply(f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function18 wrapper for a org.apache.spark.sql.api.java.UDF18 instance.
    */
-  def apply(f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function19 wrapper for a org.apache.spark.sql.api.java.UDF19 instance.
    */
-  def apply(f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function20 wrapper for a org.apache.spark.sql.api.java.UDF20 instance.
    */
-  def apply(f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function21 wrapper for a org.apache.spark.sql.api.java.UDF21 instance.
    */
-  def apply(f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
 
   /**
    * Create a scala.Function22 wrapper for a org.apache.spark.sql.api.java.UDF22 instance.
    */
-  def apply(f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
+  def apply(f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
     f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
   }
   // scalastyle:on line.size.limit
 }
+
+/**
+ * Adaptors from one UDF shape to another. For example adapting a foreach function for use in
+ * foreachPartition.
+ *
+ * Please note that this class is being used in Spark Connect Scala UDFs. We need to be careful
+ * with any modifications to this class, otherwise we will break backwards compatibility. Concretely
+ * this means you can only add methods to this class. You cannot rename the class, move it, change
+ * its `serialVersionUID`, remove methods, change method signatures, or change method semantics.
+ */
+@SerialVersionUID(0L) // TODO
+object UDFAdaptors extends Serializable {
+  def flatMapToMapPartitions[V, U](f: V => IterableOnce[U]): Iterator[V] => Iterator[U] =
+    values => values.flatMap(f)
+
+  def flatMapToMapPartitions[V, U](f: FlatMapFunction[V, U]): Iterator[V] => Iterator[U] =
+    values => values.flatMap(v => f.call(v).asScala)
+
+  def mapToMapPartitions[V, U](f: V => U): Iterator[V] => Iterator[U] = values => values.map(f)
+
+  def mapToMapPartitions[V, U](f: MapFunction[V, U]): Iterator[V] => Iterator[U] =
+    values => values.map(f.call)
+
+  def foreachToForeachPartition[T](f: T => Unit): Iterator[T] => Unit =
+    values => values.foreach(f)
+
+  def foreachToForeachPartition[T](f: ForeachFunction[T]): Iterator[T] => Unit =
+    values => values.foreach(f.call)
+
+  def foreachPartitionToMapPartitions[V, U](f: Iterator[V] => Unit): Iterator[V] => Iterator[U] =
+    values => {
+      f(values)
+      Iterator.empty[U]
+    }
+
+  def iterableOnceToSeq[A, B](f: A => IterableOnce[B]): A => Seq[B] =
+    value => f(value).iterator.toSeq
+
+  def mapGroupsToFlatMapGroups[K, V, U](f: (K, Iterator[V]) => U): (K, Iterator[V]) => Iterator[U] =
+    (key, values) => Iterator.single(f(key, values))
+
+  def mapGroupsWithStateToFlatMapWithState[K, V, S, U](
+      f: (K, Iterator[V], GroupState[S]) => U): (K, Iterator[V], GroupState[S]) => Iterator[U] =
+    (key: K, values: Iterator[V], state: GroupState[S]) => Iterator(f(key, values, state))
+
+  def coGroupWithMappedValues[K, V, U, R, IV, IU](
+      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R],
+      leftValueMapFunc: Option[IV => V],
+      rightValueMapFunc: Option[IU => U]): (K, Iterator[IV], Iterator[IU]) => IterableOnce[R] = {
+    (leftValueMapFunc, rightValueMapFunc) match {
+      case (None, None) =>
+        f.asInstanceOf[(K, Iterator[IV], Iterator[IU]) => IterableOnce[R]]
+      case (Some(mapLeft), None) =>
+        (k, left, right) => f(k, left.map(mapLeft), right.asInstanceOf[Iterator[U]])
+      case (None, Some(mapRight)) =>
+        (k, left, right) => f(k, left.asInstanceOf[Iterator[V]], right.map(mapRight))
+      case (Some(mapLeft), Some(mapRight)) =>
+        (k, left, right) => f(k, left.map(mapLeft), right.map(mapRight))
+    }
+  }
+
+  def flatMapGroupsWithMappedValues[K, IV, V, R](
+     f: (K, Iterator[V]) => IterableOnce[R],
+     valueMapFunc: Option[IV => V]): (K, Iterator[IV]) => IterableOnce[R] = valueMapFunc match {
+    case Some(mapValue) => (k, values) => f(k, values.map(mapValue))
+    case None => f.asInstanceOf[(K, Iterator[IV]) => IterableOnce[R]]
+  }
+
+  def flatMapGroupsWithStateWithMappedValues[K, IV, V, S, U](
+      f: (K, Iterator[V], GroupState[S]) => Iterator[U],
+      valueMapFunc: Option[IV => V]): (K, Iterator[IV], GroupState[S]) => Iterator[U] = {
+    valueMapFunc match {
+      case Some(mapValue) => (k, values, state) => f(k, values.map(mapValue), state)
+      case None => f.asInstanceOf[(K, Iterator[IV], GroupState[S]) => Iterator[U]]
+    }
+  }
+}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index d77b4b8..2dba8fc 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -28,6 +28,9 @@
  * mapPartitions etc. This class is shared between the client and the server so that when the
  * methods are used in client UDFs, the server will be able to find them when actually executing
  * the UDFs.
+ *
+ * DO NOT REMOVE/CHANGE THIS OBJECT OR ANY OF ITS METHODS, THEY ARE NEEDED FOR BACKWARDS
+ * COMPATIBILITY WITH OLDER CLIENTS!
  */
 @SerialVersionUID(8464839273647598302L)
 private[sql] object UdfUtils extends Serializable {
@@ -137,8 +140,6 @@
 
   // ----------------------------------------------------------------------------------------------
   // Scala Functions wrappers for java UDFs.
-  //
-  // DO NOT REMOVE THESE, THEY ARE NEEDED FOR BACKWARDS COMPATIBILITY WITH OLDER CLIENTS!
   // ----------------------------------------------------------------------------------------------
   //  (1 to 22).foreach { i =>
   //    val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 38521e8..9ae89e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -60,7 +60,7 @@
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{DataFrameWriterImpl, SQLConf}
+import org.apache.spark.sql.internal.{DataFrameWriterImpl, SQLConf, ToScalaUDF}
 import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.TypedAggUtils.withInputType
 import org.apache.spark.sql.streaming.DataStreamWriter
@@ -927,7 +927,7 @@
    * @since 2.0.0
    */
   def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(func.call(_))(encoder)
+    groupByKey(ToScalaUDF(func))(encoder)
 
   /** @inheritdoc */
   def unpivot(
@@ -1362,12 +1362,6 @@
       implicitly[Encoder[U]])
   }
 
-  /** @inheritdoc */
-  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
-    mapPartitions(func)(encoder)
-  }
-
   /**
    * Returns a new `DataFrame` that contains the result of applying a serialized R function
    * `func` to each partition.
@@ -1427,11 +1421,6 @@
   }
 
   /** @inheritdoc */
-  def foreach(f: T => Unit): Unit = withNewRDDExecutionId("foreach") {
-    rdd.foreach(f)
-  }
-
-  /** @inheritdoc */
   def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId("foreachPartition") {
     rdd.foreachPartition(f)
   }
@@ -1953,6 +1942,10 @@
     super.dropDuplicatesWithinWatermark(col1, cols: _*)
 
   /** @inheritdoc */
+  override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    super.mapPartitions(f, encoder)
+
+  /** @inheritdoc */
   override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = super.flatMap(func)
 
   /** @inheritdoc */
@@ -1960,9 +1953,6 @@
     super.flatMap(f, encoder)
 
   /** @inheritdoc */
-  override def foreach(func: ForeachFunction[T]): Unit = super.foreach(func)
-
-  /** @inheritdoc */
   override def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
     super.foreachPartition(func)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index e3ea33a..1ebdd57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -17,12 +17,10 @@
 
 package org.apache.spark.sql
 
-import scala.jdk.CollectionConverters._
-
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
@@ -41,7 +39,9 @@
     vEncoder: Encoder[V],
     @transient val queryExecution: QueryExecution,
     private val dataAttributes: Seq[Attribute],
-    private val groupingAttributes: Seq[Attribute]) extends Serializable {
+    private val groupingAttributes: Seq[Attribute])
+  extends api.KeyValueGroupedDataset[K, V, Dataset] {
+  type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]
 
   // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly.
   private implicit val kExprEnc: ExpressionEncoder[K] = encoderFor(kEncoder)
@@ -51,13 +51,7 @@
   private def sparkSession = queryExecution.sparkSession
   import queryExecution.sparkSession._
 
-  /**
-   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
-   * specified type. The mapping of key columns to the type follows the same rules as `as` on
-   * [[Dataset]].
-   *
-   * @since 1.6.0
-   */
+  /** @inheritdoc */
   def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] =
     new KeyValueGroupedDataset(
       encoderFor[L],
@@ -66,17 +60,7 @@
       dataAttributes,
       groupingAttributes)
 
-  /**
-   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
-   * to the data. The grouping key is unchanged by this.
-   *
-   * {{{
-   *   // Create values grouped by key from a Dataset[(K, V)]
-   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
-   * }}}
-   *
-   * @since 2.1.0
-   */
+  /** @inheritdoc */
   def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = {
     val withNewData = AppendColumns(func, dataAttributes, logicalPlan)
     val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData)
@@ -90,30 +74,7 @@
       groupingAttributes)
   }
 
-  /**
-   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
-   * to the data. The grouping key is unchanged by this.
-   *
-   * {{{
-   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
-   *   Dataset<Tuple2<String, Integer>> ds = ...;
-   *   KeyValueGroupedDataset<String, Integer> grouped =
-   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
-   * }}}
-   *
-   * @since 2.1.0
-   */
-  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
-    implicit val uEnc = encoder
-    mapValues { (v: V) => func.call(v) }
-  }
-
-  /**
-   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping
-   * over the Dataset to extract the keys and then running a distinct operation on those.
-   *
-   * @since 1.6.0
-   */
+  /** @inheritdoc */
   def keys: Dataset[K] = {
     Dataset[K](
       sparkSession,
@@ -121,194 +82,23 @@
         Project(groupingAttributes, logicalPlan)))
   }
 
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data.  For each unique group, the function will
-   * be passed the group key and an iterator that contains all of the elements in the group. The
-   * function can return an iterator containing elements of an arbitrary type which will be returned
-   * as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory.  However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
-   * constraints of their cluster.
-   *
-   * @since 1.6.0
-   */
-  def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
-    Dataset[U](
-      sparkSession,
-      MapGroups(
-        f,
-        groupingAttributes,
-        dataAttributes,
-        Seq.empty,
-        logicalPlan))
-  }
-
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data.  For each unique group, the function will
-   * be passed the group key and an iterator that contains all of the elements in the group. The
-   * function can return an iterator containing elements of an arbitrary type which will be returned
-   * as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory.  However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
-   * constraints of their cluster.
-   *
-   * @since 1.6.0
-   */
-  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
-    flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data.  For each unique group, the function will
-   * be passed the group key and a sorted iterator that contains all of the elements in the group.
-   * The function can return an iterator containing elements of an arbitrary type which will be
-   * returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory.  However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
-   * constraints of their cluster.
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
-   * to be sorted according to the given sort expressions. That sorting does not add
-   * computational complexity.
-   *
-   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
-   * @since 3.4.0
-   */
+  /** @inheritdoc */
   def flatMapSortedGroups[U : Encoder](
       sortExprs: Column*)(
       f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
-    val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs.map(_.expr))
-
     Dataset[U](
       sparkSession,
       MapGroups(
         f,
         groupingAttributes,
         dataAttributes,
-        sortOrder,
+        MapGroups.sortOrder(sortExprs.map(_.expr)),
         logicalPlan
       )
     )
   }
 
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data.  For each unique group, the function will
-   * be passed the group key and a sorted iterator that contains all of the elements in the group.
-   * The function can return an iterator containing elements of an arbitrary type which will be
-   * returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory.  However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
-   * constraints of their cluster.
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
-   * to be sorted according to the given sort expressions. That sorting does not add
-   * computational complexity.
-   *
-   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
-   * @since 3.4.0
-   */
-  def flatMapSortedGroups[U](
-      SortExprs: Array[Column],
-      f: FlatMapGroupsFunction[K, V, U],
-      encoder: Encoder[U]): Dataset[U] = {
-    import org.apache.spark.util.ArrayImplicits._
-    flatMapSortedGroups(
-      SortExprs.toImmutableArraySeq: _*)((key, data) => f.call(key, data.asJava).asScala)(encoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data.  For each unique group, the function will
-   * be passed the group key and an iterator that contains all of the elements in the group. The
-   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory.  However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
-   * constraints of their cluster.
-   *
-   * @since 1.6.0
-   */
-  def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
-    val func = (key: K, it: Iterator[V]) => Iterator(f(key, it))
-    flatMapGroups(func)
-  }
-
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data.  For each unique group, the function will
-   * be passed the group key and an iterator that contains all of the elements in the group. The
-   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * This function does not support partial aggregation, and as a result requires shuffling all
-   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
-   * key, it is best to use the reduce function or an
-   * `org.apache.spark.sql.expressions#Aggregator`.
-   *
-   * Internally, the implementation will spill to disk if any given group is too large to fit into
-   * memory.  However, users must take care to avoid materializing the whole iterator for a group
-   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
-   * constraints of their cluster.
-   *
-   * @since 1.6.0
-   */
-  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
-    mapGroups((key, data) => f.call(key, data.asJava))(encoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func Function to be called on every group.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.2.0
-   */
+  /** @inheritdoc */
   def mapGroupsWithState[S: Encoder, U: Encoder](
       func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
     val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
@@ -324,23 +114,7 @@
         child = logicalPlan))
   }
 
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func Function to be called on every group.
-   * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.2.0
-   */
+  /** @inheritdoc */
   def mapGroupsWithState[S: Encoder, U: Encoder](
       timeoutConf: GroupStateTimeout)(
       func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
@@ -357,29 +131,7 @@
         child = logicalPlan))
   }
 
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func Function to be called on every group.
-   * @param timeoutConf Timeout Conf, see GroupStateTimeout for more details
-   * @param initialState The user provided state that will be initialized when the first batch
-   *                     of data is processed in the streaming query. The user defined function
-   *                     will be called on the state data even if there are no other values in
-   *                     the group. To convert a Dataset ds of type Dataset[(K, S)] to a
-   *                     KeyValueGroupedDataset[K, S]
-   *                     do {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.2.0
-   */
+  /** @inheritdoc */
   def mapGroupsWithState[S: Encoder, U: Encoder](
       timeoutConf: GroupStateTimeout,
       initialState: KeyValueGroupedDataset[K, S])(
@@ -402,114 +154,7 @@
       ))
   }
 
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See `GroupState` for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func          Function to be called on every group.
-   * @param stateEncoder  Encoder for the state type.
-   * @param outputEncoder Encoder for the output type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.2.0
-   */
-  def mapGroupsWithState[S, U](
-      func: MapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U]): Dataset[U] = {
-    mapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
-    )(stateEncoder, outputEncoder)
-  }
-
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See `GroupState` for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func          Function to be called on every group.
-   * @param stateEncoder  Encoder for the state type.
-   * @param outputEncoder Encoder for the output type.
-   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.2.0
-   */
-  def mapGroupsWithState[S, U](
-      func: MapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U],
-      timeoutConf: GroupStateTimeout): Dataset[U] = {
-    mapGroupsWithState[S, U](timeoutConf)(
-      (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
-    )(stateEncoder, outputEncoder)
-  }
-
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See `GroupState` for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func          Function to be called on every group.
-   * @param stateEncoder  Encoder for the state type.
-   * @param outputEncoder Encoder for the output type.
-   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
-   * @param initialState The user provided state that will be initialized when the first batch
-   *                     of data is processed in the streaming query. The user defined function
-   *                     will be called on the state data even if there are no other values in
-   *                     the group.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.2.0
-   */
-  def mapGroupsWithState[S, U](
-      func: MapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U],
-      timeoutConf: GroupStateTimeout,
-      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
-    mapGroupsWithState[S, U](timeoutConf, initialState)(
-      (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
-    )(stateEncoder, outputEncoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See `GroupState` for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func Function to be called on every group.
-   * @param outputMode The output mode of the function.
-   * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.2.0
-   */
+  /** @inheritdoc */
   def flatMapGroupsWithState[S: Encoder, U: Encoder](
       outputMode: OutputMode,
       timeoutConf: GroupStateTimeout)(
@@ -529,29 +174,7 @@
         child = logicalPlan))
   }
 
-  /**
-   * (Scala-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See `GroupState` for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func Function to be called on every group.
-   * @param outputMode The output mode of the function.
-   * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
-   * @param initialState The user provided state that will be initialized when the first batch
-   *                     of data is processed in the streaming query. The user defined function
-   *                     will be called on the state data even if there are no other values in
-   *                     the group. To covert a Dataset `ds` of type  of type `Dataset[(K, S)]`
-   *                     to a `KeyValueGroupedDataset[K, S]`, use
-   *                     {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.2.0
-   */
+  /** @inheritdoc */
   def flatMapGroupsWithState[S: Encoder, U: Encoder](
       outputMode: OutputMode,
       timeoutConf: GroupStateTimeout,
@@ -576,91 +199,7 @@
       ))
   }
 
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See `GroupState` for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func          Function to be called on every group.
-   * @param outputMode    The output mode of the function.
-   * @param stateEncoder  Encoder for the state type.
-   * @param outputEncoder Encoder for the output type.
-   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.2.0
-   */
-  def flatMapGroupsWithState[S, U](
-      func: FlatMapGroupsWithStateFunction[K, V, S, U],
-      outputMode: OutputMode,
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U],
-      timeoutConf: GroupStateTimeout): Dataset[U] = {
-    val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
-    flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
-  }
-
-  /**
-   * (Java-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See `GroupState` for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func          Function to be called on every group.
-   * @param outputMode    The output mode of the function.
-   * @param stateEncoder  Encoder for the state type.
-   * @param outputEncoder Encoder for the output type.
-   * @param timeoutConf   Timeout configuration for groups that do not receive data for a while.
-   * @param initialState The user provided state that will be initialized when the first batch
-   *                     of data is processed in the streaming query. The user defined function
-   *                     will be called on the state data even if there are no other values in
-   *                     the group. To covert a Dataset `ds` of type  of type `Dataset[(K, S)]`
-   *                     to a `KeyValueGroupedDataset[K, S]`, use
-   *                     {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 3.2.0
-   */
-  def flatMapGroupsWithState[S, U](
-      func: FlatMapGroupsWithStateFunction[K, V, S, U],
-      outputMode: OutputMode,
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U],
-      timeoutConf: GroupStateTimeout,
-      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
-    val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
-    flatMapGroupsWithState[S, U](
-      outputMode, timeoutConf, initialState)(f)(stateEncoder, outputEncoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * We allow the user to act on per-group set of input rows along with keyed state and the
-   * user can choose to output/return 0 or more rows.
-   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
-   * in each trigger and the user's state/state variables will be stored persistently across
-   * invocations.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked
-   *                          by the operator.
-   * @param timeMode          The time mode semantics of the stateful processor for timers and TTL.
-   * @param outputMode        The output mode of the stateful processor.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
+  /** @inheritdoc */
   private[sql] def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       timeMode: TimeMode,
@@ -678,29 +217,7 @@
     )
   }
 
-  /**
-   * (Scala-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * We allow the user to act on per-group set of input rows along with keyed state and the
-   * user can choose to output/return 0 or more rows.
-   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
-   * in each trigger and the user's state/state variables will be stored persistently across
-   * invocations.
-   *
-   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
-   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor   Instance of statefulProcessor whose functions will
-   *                            be invoked by the operator.
-   * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
-   *                            transformWithState will use the new eventTimeColumn. The user
-   *                            needs to ensure that the eventTime for emitted output adheres to
-   *                            the watermark boundary, otherwise streaming query will fail.
-   * @param outputMode          The output mode of the stateful processor.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
+  /** @inheritdoc */
   private[sql] def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       eventTimeColumnName: String,
@@ -716,81 +233,7 @@
     updateEventTimeColumnAfterTransformWithState(transformWithState, eventTimeColumnName)
   }
 
-  /**
-   * (Java-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * We allow the user to act on per-group set of input rows along with keyed state and the
-   * user can choose to output/return 0 or more rows.
-   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
-   * in each trigger and the user's state/state variables will be stored persistently across
-   * invocations.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
-   *                          operator.
-   * @param timeMode The time mode semantics of the stateful processor for timers and TTL.
-   * @param outputMode The output mode of the stateful processor.
-   * @param outputEncoder Encoder for the output type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
-  private[sql] def transformWithState[U: Encoder](
-      statefulProcessor: StatefulProcessor[K, V, U],
-      timeMode: TimeMode,
-      outputMode: OutputMode,
-      outputEncoder: Encoder[U]): Dataset[U] = {
-    transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder)
-  }
-
-  /**
-   * (Java-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * We allow the user to act on per-group set of input rows along with keyed state and the
-   * user can choose to output/return 0 or more rows.
-   *
-   * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
-   * in each trigger and the user's state/state variables will be stored persistently across
-   * invocations.
-   *
-   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
-   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
-   *                          operator.
-   * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
-   *                            transformWithState will use the new eventTimeColumn. The user
-   *                            needs to ensure that the eventTime for emitted output adheres to
-   *                            the watermark boundary, otherwise streaming query will fail.
-   * @param outputMode        The output mode of the stateful processor.
-   * @param outputEncoder     Encoder for the output type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
-  private[sql] def transformWithState[U: Encoder](
-      statefulProcessor: StatefulProcessor[K, V, U],
-      eventTimeColumnName: String,
-      outputMode: OutputMode,
-      outputEncoder: Encoder[U]): Dataset[U] = {
-    transformWithState(statefulProcessor, eventTimeColumnName, outputMode)(outputEncoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * Functions as the function above, but with additional initial state.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor Instance of statefulProcessor whose functions will
-   *                          be invoked by the operator.
-   * @param timeMode          The time mode semantics of the stateful processor for timers and TTL.
-   * @param outputMode        The output mode of the stateful processor.
-   * @param initialState      User provided initial state that will be used to initiate state for
-   *                          the query in the first batch.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
+  /** @inheritdoc */
   private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       timeMode: TimeMode,
@@ -812,29 +255,7 @@
     )
   }
 
-  /**
-   * (Scala-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * Functions as the function above, but with additional eventTimeColumnName for output.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
-   *
-   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
-   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
-   *
-   * @param statefulProcessor   Instance of statefulProcessor whose functions will
-   *                            be invoked by the operator.
-   * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
-   *                            transformWithState will use the new eventTimeColumn. The user
-   *                            needs to ensure that the eventTime for emitted output adheres to
-   *                            the watermark boundary, otherwise streaming query will fail.
-   * @param outputMode          The output mode of the stateful processor.
-   * @param initialState        User provided initial state that will be used to initiate state for
-   *                            the query in the first batch.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
+  /** @inheritdoc */
   private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       eventTimeColumnName: String,
@@ -856,71 +277,6 @@
   }
 
   /**
-   * (Java-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * Functions as the function above, but with additional initialStateEncoder for state encoding.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor   Instance of statefulProcessor whose functions will
-   *                            be invoked by the operator.
-   * @param timeMode            The time mode semantics of the stateful processor for
-   *                            timers and TTL.
-   * @param outputMode          The output mode of the stateful processor.
-   * @param initialState        User provided initial state that will be used to initiate state for
-   *                            the query in the first batch.
-   * @param outputEncoder       Encoder for the output type.
-   * @param initialStateEncoder Encoder for the initial state type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
-  private[sql] def transformWithState[U: Encoder, S: Encoder](
-      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
-      timeMode: TimeMode,
-      outputMode: OutputMode,
-      initialState: KeyValueGroupedDataset[K, S],
-      outputEncoder: Encoder[U],
-      initialStateEncoder: Encoder[S]): Dataset[U] = {
-    transformWithState(statefulProcessor, timeMode,
-      outputMode, initialState)(outputEncoder, initialStateEncoder)
-  }
-
-  /**
-   * (Java-specific)
-   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
-   * Functions as the function above, but with additional eventTimeColumnName for output.
-   *
-   * Downstream operators would use specified eventTimeColumnName to calculate watermark.
-   * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
-   *
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
-   * @param statefulProcessor Instance of statefulProcessor whose functions will
-   *                          be invoked by the operator.
-   * @param outputMode        The output mode of the stateful processor.
-   * @param initialState      User provided initial state that will be used to initiate state for
-   *                          the query in the first batch.
-   * @param eventTimeColumnName event column in the output dataset. Any operations after
-   *                            transformWithState will use the new eventTimeColumn. The user
-   *                            needs to ensure that the eventTime for emitted output adheres to
-   *                            the watermark boundary, otherwise streaming query will fail.
-   * @param outputEncoder     Encoder for the output type.
-   * @param initialStateEncoder Encoder for the initial state type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   */
-  private[sql] def transformWithState[U: Encoder, S: Encoder](
-      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
-      outputMode: OutputMode,
-      initialState: KeyValueGroupedDataset[K, S],
-      eventTimeColumnName: String,
-      outputEncoder: Encoder[U],
-      initialStateEncoder: Encoder[S]): Dataset[U] = {
-    transformWithState(statefulProcessor, eventTimeColumnName,
-      outputMode, initialState)(outputEncoder, initialStateEncoder)
-  }
-
-  /**
    * Creates a new dataset with updated eventTimeColumn after the transformWithState
    * logical node.
    */
@@ -939,35 +295,14 @@
         transformWithStateDataset.logicalPlan)))
   }
 
-  /**
-   * (Scala-specific)
-   * Reduces the elements of each group of data using the specified binary function.
-   * The given function must be commutative and associative or the result may be non-deterministic.
-   *
-   * @since 1.6.0
-   */
+  /** @inheritdoc */
   def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
     val vEncoder = encoderFor[V]
     val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn
     agg(aggregator)
   }
 
-  /**
-   * (Java-specific)
-   * Reduces the elements of each group of data using the specified binary function.
-   * The given function must be commutative and associative or the result may be non-deterministic.
-   *
-   * @since 1.6.0
-   */
-  def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
-    reduceGroups(f.call _)
-  }
-
-  /**
-   * Internal helper function for building typed aggregations that return tuples.  For simplicity
-   * and code reuse, we do this without the help of the type system and then use helper functions
-   * that cast appropriately for the user facing interface.
-   */
+  /** @inheritdoc */
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(c => encoderFor(c.encoder))
     val namedColumns = columns.map(c => withInputType(c.named, vExprEnc, dataAttributes))
@@ -978,192 +313,12 @@
     new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders))
   }
 
-  /**
-   * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing this aggregation over all elements in the group.
-   *
-   * @since 1.6.0
-   */
-  def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
-    aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing these aggregations over all elements in the group.
-   *
-   * @since 1.6.0
-   */
-  def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
-    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing these aggregations over all elements in the group.
-   *
-   * @since 1.6.0
-   */
-  def agg[U1, U2, U3](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
-    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing these aggregations over all elements in the group.
-   *
-   * @since 1.6.0
-   */
-  def agg[U1, U2, U3, U4](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
-    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.0.0
-   */
-  def agg[U1, U2, U3, U4, U5](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
-    aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.0.0
-   */
-  def agg[U1, U2, U3, U4, U5, U6](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5],
-      col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
-    aggUntyped(col1, col2, col3, col4, col5, col6)
-      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.0.0
-   */
-  def agg[U1, U2, U3, U4, U5, U6, U7](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5],
-      col6: TypedColumn[V, U6],
-      col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
-    aggUntyped(col1, col2, col3, col4, col5, col6, col7)
-      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]]
-
-  /**
-   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
-   * and the result of computing these aggregations over all elements in the group.
-   *
-   * @since 3.0.0
-   */
-  def agg[U1, U2, U3, U4, U5, U6, U7, U8](
-      col1: TypedColumn[V, U1],
-      col2: TypedColumn[V, U2],
-      col3: TypedColumn[V, U3],
-      col4: TypedColumn[V, U4],
-      col5: TypedColumn[V, U5],
-      col6: TypedColumn[V, U6],
-      col7: TypedColumn[V, U7],
-      col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
-    aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8)
-      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]]
-
-  /**
-   * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
-   * for that key.
-   *
-   * @since 1.6.0
-   */
-  def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]()))
-
-  /**
-   * (Scala-specific)
-   * Applies the given function to each cogrouped data.  For each unique group, the function will
-   * be passed the grouping key and 2 iterators containing all elements in the group from
-   * [[Dataset]] `this` and `other`.  The function can return an iterator containing elements of an
-   * arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * @since 1.6.0
-   */
-  def cogroup[U, R : Encoder](
-      other: KeyValueGroupedDataset[K, U])(
-      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
-    implicit val uEncoder = other.vExprEnc
-    Dataset[R](
-      sparkSession,
-      CoGroup(
-        f,
-        this.groupingAttributes,
-        other.groupingAttributes,
-        this.dataAttributes,
-        other.dataAttributes,
-        Seq.empty,
-        Seq.empty,
-        this.logicalPlan,
-        other.logicalPlan))
-  }
-
-  /**
-   * (Java-specific)
-   * Applies the given function to each cogrouped data.  For each unique group, the function will
-   * be passed the grouping key and 2 iterators containing all elements in the group from
-   * [[Dataset]] `this` and `other`.  The function can return an iterator containing elements of an
-   * arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * @since 1.6.0
-   */
-  def cogroup[U, R](
-      other: KeyValueGroupedDataset[K, U],
-      f: CoGroupFunction[K, V, U, R],
-      encoder: Encoder[R]): Dataset[R] = {
-    cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Applies the given function to each sorted cogrouped data.  For each unique group, the function
-   * will be passed the grouping key and 2 sorted iterators containing all elements in the group
-   * from [[Dataset]] `this` and `other`.  The function can return an iterator containing elements
-   * of an arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators
-   * to be sorted according to the given sort expressions. That sorting does not add
-   * computational complexity.
-   *
-   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
-   * @since 3.4.0
-   */
+  /** @inheritdoc */
   def cogroupSorted[U, R : Encoder](
       other: KeyValueGroupedDataset[K, U])(
       thisSortExprs: Column*)(
       otherSortExprs: Column*)(
       f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
-    def toSortOrder(col: Column): SortOrder = col.expr match {
-      case expr: SortOrder => expr
-      case expr: Expression => SortOrder(expr, Ascending)
-    }
-
-    val thisSortOrder: Seq[SortOrder] = thisSortExprs.map(toSortOrder)
-    val otherSortOrder: Seq[SortOrder] = otherSortExprs.map(toSortOrder)
-
     implicit val uEncoder = other.vExprEnc
     Dataset[R](
       sparkSession,
@@ -1173,45 +328,19 @@
         other.groupingAttributes,
         this.dataAttributes,
         other.dataAttributes,
-        thisSortOrder,
-        otherSortOrder,
+        MapGroups.sortOrder(thisSortExprs.map(_.expr)),
+        MapGroups.sortOrder(otherSortExprs.map(_.expr)),
         this.logicalPlan,
         other.logicalPlan))
   }
 
-  /**
-   * (Java-specific)
-   * Applies the given function to each sorted cogrouped data.  For each unique group, the function
-   * will be passed the grouping key and 2 sorted iterators containing all elements in the group
-   * from [[Dataset]] `this` and `other`.  The function can return an iterator containing elements
-   * of an arbitrary type which will be returned as a new [[Dataset]].
-   *
-   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators
-   * to be sorted according to the given sort expressions. That sorting does not add
-   * computational complexity.
-   *
-   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
-   * @since 3.4.0
-   */
-  def cogroupSorted[U, R](
-      other: KeyValueGroupedDataset[K, U],
-      thisSortExprs: Array[Column],
-      otherSortExprs: Array[Column],
-      f: CoGroupFunction[K, V, U, R],
-      encoder: Encoder[R]): Dataset[R] = {
-    import org.apache.spark.util.ArrayImplicits._
-    cogroupSorted(other)(
-      thisSortExprs.toImmutableArraySeq: _*)(otherSortExprs.toImmutableArraySeq: _*)(
-      (key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
-  }
-
   override def toString: String = {
     val builder = new StringBuilder
-    val kFields = kExprEnc.schema.map {
-      case f => s"${f.name}: ${f.dataType.simpleString(2)}"
+    val kFields = kExprEnc.schema.map { f =>
+      s"${f.name}: ${f.dataType.simpleString(2)}"
     }
-    val vFields = vExprEnc.schema.map {
-      case f => s"${f.name}: ${f.dataType.simpleString(2)}"
+    val vFields = vExprEnc.schema.map { f =>
+      s"${f.name}: ${f.dataType.simpleString(2)}"
     }
     builder.append("KeyValueGroupedDataset: [key: [")
     builder.append(kFields.take(2).mkString(", "))
@@ -1225,4 +354,221 @@
     }
     builder.append("]]").toString()
   }
+
+  ////////////////////////////////////////////////////////////////////////////
+  // Return type overrides to make sure we return the implementation instead
+  // of the interface.
+  ////////////////////////////////////////////////////////////////////////////
+  /** @inheritdoc */
+  override def mapValues[W](
+      func: MapFunction[V, W],
+      encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder)
+
+  /** @inheritdoc */
+  override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] =
+    super.flatMapGroups(f)
+
+  /** @inheritdoc */
+  override def flatMapGroups[U](
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder)
+
+  /** @inheritdoc */
+  override def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder)
+
+  /** @inheritdoc */
+  override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f)
+
+  /** @inheritdoc */
+  override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] =
+    super.mapGroups(f, encoder)
+
+  /** @inheritdoc */
+  override def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] =
+    super.mapGroupsWithState(func, stateEncoder, outputEncoder)
+
+  /** @inheritdoc */
+  override def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout): Dataset[U] =
+    super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf)
+
+  /** @inheritdoc */
+  override def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+    super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState)
+
+  /** @inheritdoc */
+  override def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout): Dataset[U] =
+    super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf)
+
+  /** @inheritdoc */
+  override def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+    super.flatMapGroupsWithState(
+      func,
+      outputMode,
+      stateEncoder,
+      outputEncoder,
+      timeoutConf,
+      initialState)
+
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      outputEncoder: Encoder[U]) =
+    super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder)
+
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      eventTimeColumnName: String,
+      outputMode: OutputMode,
+      outputEncoder: Encoder[U]) =
+    super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder)
+
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S],
+      outputEncoder: Encoder[U],
+      initialStateEncoder: Encoder[S]) = super.transformWithState(
+    statefulProcessor,
+    timeMode,
+    outputMode,
+    initialState,
+    outputEncoder,
+    initialStateEncoder)
+
+  /** @inheritdoc */
+  override private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S],
+      eventTimeColumnName: String,
+      outputEncoder: Encoder[U],
+      initialStateEncoder: Encoder[S]) = super.transformWithState(
+    statefulProcessor,
+    outputMode,
+    initialState,
+    eventTimeColumnName,
+    outputEncoder,
+    initialStateEncoder)
+
+  /** @inheritdoc */
+  override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f)
+
+  /** @inheritdoc */
+  override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1)
+
+  /** @inheritdoc */
+  override def agg[U1, U2](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
+    super.agg(col1, col2, col3, col4, col5)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5, U6](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
+    super.agg(col1, col2, col3, col4, col5, col6)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5, U6, U7](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
+    super.agg(col1, col2, col3, col4, col5, col6, col7)
+
+  /** @inheritdoc */
+  override def agg[U1, U2, U3, U4, U5, U6, U7, U8](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7],
+      col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
+    super.agg(col1, col2, col3, col4, col5, col6, col7, col8)
+
+  /** @inheritdoc */
+  override def count(): Dataset[(K, Long)] = super.count()
+
+  /** @inheritdoc */
+  override def cogroup[U, R: Encoder](
+      other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] =
+    super.cogroup(other)(f)
+
+  /** @inheritdoc */
+  override def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] =
+    super.cogroup(other, f, encoder)
+
+  /** @inheritdoc */
+  override def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] =
+    super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 777baa3..4e44540 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -117,13 +117,7 @@
     columnExprs.map(column)
   }
 
-
-  /**
-   * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions
-   * of current `RelationalGroupedDataset`.
-   *
-   * @since 3.0.0
-   */
+  /** @inheritdoc */
   def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
     val keyEncoder = encoderFor[K]
     val valueEncoder = encoderFor[T]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 808f783..be91f5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -118,7 +118,7 @@
     sparkContext.listenerBus.waitUntilEmpty()
     assert(metrics.length == 2)
 
-    assert(metrics(0)._1 == "foreach")
+    assert(metrics(0)._1 == "foreachPartition")
     assert(metrics(1)._1 == "reduce")
 
     spark.listenerManager.unregister(listener)