package io.pivotal.gemfire.spark.connector.internal.oql

import com.gemstone.gemfire.cache.query.internal.StructImpl
import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
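/**
 * Builds a Spark SQL schema (StructType) for the result of an OQL query by
 * sampling the first element of the given QueryRDD and mapping the Java
 * classes of its field values to Catalyst DataTypes.
 */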
class SchemaBuilder[T](queryRDD: QueryRDD[T]) extends Logging {
  val nullStructType = StructType(Nil)

  // Maps Java value classes to Spark SQL Catalyst DataTypes. Classes with no
  // entry fall back to `nullStructType` (an empty StructType) via getOrElse,
  // and java.lang.Object maps to it explicitly, marking the type as unknown.
  val typeMap: Map[Class[_], DataType] = Map(
    (classOf[java.lang.String], StringType),
    (classOf[java.lang.Integer], IntegerType),
    (classOf[java.lang.Short], ShortType),
    (classOf[java.lang.Long], LongType),
    (classOf[java.lang.Double], DoubleType),
    (classOf[java.lang.Float], FloatType),
    (classOf[java.lang.Boolean], BooleanType),
    (classOf[java.lang.Byte], ByteType),
    (classOf[java.util.Date], DateType),
    (classOf[java.lang.Object], nullStructType)
  )

  /**
   * Analyse the first element of the QueryRDD to infer the Spark schema.
   * @return The schema represented as a Spark StructType
   */
  def toSparkSchema(): StructType = {
    // Sample one element; note that first() fails on an empty RDD.
    val row = queryRDD.first()
    val tpe = row match {
      // Multi-column result: build one StructField per struct field.
      case r: StructImpl => constructFromStruct(r)
      // A null sample gives no type information for the single column.
      case null => StructType(StructField("col1", NullType) :: Nil)
      // Single-column result: map the value's class to a Catalyst type.
      case default =>
        val value = typeMap.getOrElse(default.getClass(), nullStructType)
        StructType(StructField("col1", value) :: Nil)
    }
    logInfo(s"Schema: $tpe")
    tpe
  }
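  /**
   * Build a StructType from a GemFire StructImpl, pairing each field name
   * with the DataType inferred from the corresponding field value.
   */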
  def constructFromStruct(r: StructImpl): StructType = {
    val names = r.getFieldNames
    val values = r.getFieldValues
    val lb = new ListBuffer[StructField]()
    for (i <- 0 until names.length) {
      val name = names(i)
      val value = values(i)
      val dataType = value match {
        case null => NullType  // a null value carries no type information
        case default => typeMap.getOrElse(default.getClass, nullStructType)
      }
      lb += StructField(name, dataType)
    }
    StructType(lb.toSeq)
  }
}
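// A minimal usage sketch (not part of the original file): obtaining the
// QueryRDD from an OQL query is elided, and `sqlContext` / `rowRDD` in the
// last step are assumed to be in scope; only SchemaBuilder#toSparkSchema
// comes from this file.
//
//   val queryRDD: QueryRDD[Object] = ...  // result of an OQL query
//   val schema = new SchemaBuilder(queryRDD).toSparkSchema()
//   // With a compatible RDD[Row], the schema can back a DataFrame:
//   // val df = sqlContext.createDataFrame(rowRDD, schema)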