/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.api.ml

import org.apache.spark.sql.functions.udf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.SparkContext
import org.apache.sysml.runtime.matrix.data.MatrixBlock
import org.apache.sysml.runtime.DMLRuntimeException
import org.apache.sysml.runtime.matrix.MatrixCharacteristics
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils
import org.apache.sysml.api.mlcontext.MLResults
import org.apache.sysml.api.mlcontext.ScriptFactory._
import org.apache.sysml.api.mlcontext.Script
import org.apache.sysml.api.mlcontext.BinaryBlockMatrix
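
/**
 * Shared helpers for the prediction path: building prediction scripts, mapping between
 * original labels and the 1-based numeric labels used internally, and post-processing
 * prediction outputs.
 */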
object PredictionUtils {
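  /**
   * Builds the GLM prediction script and binds the model matrix B_full, either as an
   * in-memory MatrixBlock with metadata (single-node) or as a BinaryBlockMatrix
   * (distributed). Returns the configured Script together with the name of the feature
   * input ("X") that the caller still has to bind.
   */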
  def getGLMPredictionScript(B_full: BinaryBlockMatrix, isSingleNode: Boolean, dfam: java.lang.Integer = 1): (Script, String) = {
    val script = dml(ScriptsUtils.getDMLScript(LogisticRegressionModel.scriptPath))
      .in("$X", " ")
      .in("$B", " ")
      .in("$dfam", dfam)
      .out("means")
    val ret = if(isSingleNode) {
      script.in("B_full", B_full.getMatrixBlock, B_full.getMatrixMetadata)
    }
    else {
      script.in("B_full", B_full)
    }
    (ret, "X")
  }
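
  /**
   * Collects the distinct values of the "label" column, assigns each a 1-based index,
   * records the reverse mapping in revLabelMapping, and returns the label column
   * re-coded to those indices as an RDD of strings.
   */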
  def fillLabelMapping(df: ScriptsUtils.SparkDataType, revLabelMapping: java.util.HashMap[Int, String]): RDD[String] = {
    val temp = df.select("label").distinct.rdd.map(_.apply(0).toString).collect()
    val labelMapping = new java.util.HashMap[String, Int]
    for(i <- 0 until temp.length) {
      labelMapping.put(temp(i), i + 1)
      revLabelMapping.put(i + 1, temp(i))
    }
    df.select("label").rdd.map( x => labelMapping.get(x.apply(0).toString).toString )
  }
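
  /**
   * In-place variant for a dense column-vector MatrixBlock: replaces every label value
   * with its 1-based index and records the reverse mapping in revLabelMapping.
   * Sparse label blocks are not supported.
   */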
  def fillLabelMapping(y_mb: MatrixBlock, revLabelMapping: java.util.HashMap[Int, String]): Unit = {
    val labelMapping = new java.util.HashMap[String, Int]
    if(y_mb.getNumColumns != 1) {
      throw new RuntimeException("Expected a column vector for y")
    }
    if(y_mb.isInSparseFormat()) {
      throw new DMLRuntimeException("Sparse block is not implemented for fit")
    }
    else {
      val denseBlock = y_mb.getDenseBlock()
      var id: Int = 1
      for(i <- 0 until denseBlock.length) {
        val v = denseBlock(i).toString()
        if(!labelMapping.containsKey(v)) {
          labelMapping.put(v, id)
          revLabelMapping.put(id, v)
          id += 1
        }
        denseBlock.update(i, labelMapping.get(v))
      }
    }
  }
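
  /**
   * Serializable wrapper around the reverse label mapping so it can be shipped to executors.
   * mapLabel_udf maps a predicted numeric label back to the original label, returning a
   * Double when every original label parses as a number and a String otherwise.
   */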
  class LabelMappingData(val labelMapping: java.util.HashMap[Int, String]) extends Serializable {
    def mapLabelStr(x: Double): String = {
      if(labelMapping.containsKey(x.toInt))
        labelMapping.get(x.toInt)
      else
        throw new RuntimeException("Incorrect label mapping")
    }
    def mapLabelDouble(x: Double): Double = {
      if(labelMapping.containsKey(x.toInt))
        labelMapping.get(x.toInt).toDouble
      else
        throw new RuntimeException("Incorrect label mapping")
    }
    val mapLabel_udf = {
      try {
        val it = labelMapping.values().iterator()
        while(it.hasNext()) {
          it.next().toDouble
        }
        udf(mapLabelDouble _)
      } catch {
        case e: Exception => udf(mapLabelStr _)
      }
    }
  }
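
  /**
   * Maps predicted numeric labels back to the original labels. In single-node mode the
   * dense MatrixBlock X is updated in place and null is returned; otherwise the mapping
   * is applied to the given DataFrame column, which is then renamed to "prediction".
   */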
  def updateLabels(isSingleNode: Boolean, df: DataFrame, X: MatrixBlock, labelColName: String, labelMapping: java.util.HashMap[Int, String]): DataFrame = {
    if(isSingleNode) {
      if(X.isInSparseFormat()) {
        throw new RuntimeException("Since predicted label is a column vector, expected it to be in dense format")
      }
      for(i <- 0 until X.getNumRows) {
        val v: Int = X.getValue(i, 0).toInt
        if(labelMapping.containsKey(v)) {
          X.setValue(i, 0, labelMapping.get(v).toDouble)
        }
        else {
          throw new RuntimeException("No mapping found for " + v + " in " + labelMapping.toString())
        }
      }
      return null
    }
    else {
      val serObj = new LabelMappingData(labelMapping)
      return df.withColumn(labelColName, serObj.mapLabel_udf(df(labelColName)))
               .withColumnRenamed(labelColName, "prediction")
    }
  }
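
  /** Joins two DataFrames on the internal row-id column added by RDDConverterUtils. */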
  def joinUsingID(df1: DataFrame, df2: DataFrame): DataFrame = {
    df1.join(df2, RDDConverterUtils.DF_ID_COLUMN)
  }
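
  /**
   * Runs a small DML script that converts the class-probability matrix produced by a
   * scoring script into predicted class labels via rowIndexMax, consistent with the
   * 1-based label mapping established above.
   */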
  def computePredictedClassLabelsFromProbability(mlscoreoutput: MLResults, isSingleNode: Boolean, sc: SparkContext, inProbVar: String): MLResults = {
    val ml = new org.apache.sysml.api.mlcontext.MLContext(sc)
    val script = dml(
      """
      Prob = read("temp1");
      Prediction = rowIndexMax(Prob); # assuming one-based label mapping
      write(Prediction, "tempOut", "csv");
      """).out("Prediction")
    val probVar = mlscoreoutput.getBinaryBlockMatrix(inProbVar)
    if(isSingleNode) {
      ml.execute(script.in("Prob", probVar.getMatrixBlock, probVar.getMatrixMetadata))
    }
    else {
      ml.execute(script.in("Prob", probVar))
    }
  }
}