[SPARK-23046][ML][SPARKR] Have RFormula include VectorSizeHint in pipeline
## What changes were proposed in this pull request?
Including VectorSizeHint in RFormula pipelines will allow them to be applied to streaming dataframes.
## How was this patch tested?
Unit tests.
Author: Bago Amirbekian <bago@databricks.com>
Closes #20238 from MrBago/rFormulaVectorSize.
(cherry picked from commit 186bf8fb2e9ff8a80f3f6bcb5f2a0327fa79a1c9)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
index a53c92c..23dda42 100644
--- a/R/pkg/R/mllib_utils.R
+++ b/R/pkg/R/mllib_utils.R
@@ -130,3 +130,4 @@
stop("Unsupported model: ", jobj)
}
}
+
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 7da3339..f384ffbf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -25,7 +25,7 @@
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.VectorUDT
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol}
import org.apache.spark.ml.util._
@@ -210,8 +210,8 @@
// First we index each string column referenced by the input terms.
val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term =>
- dataset.schema(term) match {
- case column if column.dataType == StringType =>
+ dataset.schema(term).dataType match {
+ case _: StringType =>
val indexCol = tmpColumn("stridx")
encoderStages += new StringIndexer()
.setInputCol(term)
@@ -220,6 +220,18 @@
.setHandleInvalid($(handleInvalid))
prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
+ case _: VectorUDT =>
+ val group = AttributeGroup.fromStructField(dataset.schema(term))
+ val size = if (group.size < 0) {
+ dataset.select(term).first().getAs[Vector](0).size
+ } else {
+ group.size
+ }
+ encoderStages += new VectorSizeHint(uid)
+ .setHandleInvalid("optimistic")
+ .setInputCol(term)
+ .setSize(size)
+ (term, term)
case _ =>
(term, term)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 5d09c90..f3f4b5a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -17,15 +17,15 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
import org.apache.spark.sql.types.DoubleType
-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class RFormulaSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -548,4 +548,31 @@
assert(result3.collect() === expected3.collect())
assert(result4.collect() === expected4.collect())
}
+
+ test("Use Vectors as inputs to formula.") {
+ val original = Seq(
+ (1, 4, Vectors.dense(0.0, 0.0, 4.0)),
+ (2, 4, Vectors.dense(1.0, 0.0, 4.0)),
+ (3, 5, Vectors.dense(1.0, 0.0, 5.0)),
+ (4, 5, Vectors.dense(0.0, 1.0, 5.0))
+ ).toDF("id", "a", "b")
+ val formula = new RFormula().setFormula("id ~ a + b")
+ val (first +: rest) = Seq("id", "a", "b", "features", "label")
+ testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) {
+ case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) =>
+ assert(label === id)
+ assert(features.toArray === a +: b.toArray)
+ }
+
+ val group = new AttributeGroup("b", 3)
+ val vectorColWithMetadata = original("b").as("b", group.toMetadata())
+ val dfWithMetadata = original.withColumn("b", vectorColWithMetadata)
+ val model = formula.fit(dfWithMetadata)
+ // model should work even when applied to dataframe without metadata.
+ testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) {
+ case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) =>
+ assert(label === id)
+ assert(features.toArray === a +: b.toArray)
+ }
+ }
}