blob: 06fedc9188e36b3187f101f9c5a3c3da10ceb1b7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.api;
import java.io.IOException;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext.QueryExecution;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.GetMIMBFromRow;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
/**
* Experimental API: Might be discontinued in a future release.
*
* This class serves four purposes:
* 1. It allows SystemML to fit nicely in MLPipeline by reducing the number of reblocks.
* 2. It allows users to easily read and write matrices without worrying
* too much about the format, metadata and type of the underlying RDDs.
* 3. It provides a mechanism to convert to and from MLlib's BlockMatrix format.
* 4. It provides an off-the-shelf library for distributed blocked matrices and reduces the learning curve for using SystemML.
* However, it is important to know that it is easy to abuse this off-the-shelf library and think of it as a replacement
* for writing DML, which it is not. It does not provide any optimization between calls. A simple example
* of an optimization that is conveniently skipped is: (t(m) %*% m).
* Also, note that this library is not thread-safe. The operator precedence is not exactly the same as in DML (the precedence is
* enforced by the Scala compiler), so please use appropriate brackets to enforce precedence.
* <pre>
* import org.apache.sysml.api.{MLContext, MLMatrix}
* val ml = new MLContext(sc)
* val mat1 = ml.read(sqlContext, "V_small.csv", "csv")
* val mat2 = ml.read(sqlContext, "W_small.mtx", "binary")
* val result = mat1.transpose() %*% mat2
* result.write("Result_small.mtx", "text")
* </pre>
*/
public class MLMatrix extends DataFrame {
	private static final long serialVersionUID = -7005940673916671165L;
	// NOTE(review): registered under DMLScript's name rather than MLMatrix's, and not used inside
	// this class -- confirm that sharing DMLScript's log category is intentional.
	protected static final Log LOG = LogFactory.getLog(DMLScript.class.getName());
	// Dimensions, block sizes and non-zero count of the binary-block matrix behind this DataFrame.
	// Only the DataFrame-based constructor sets it; the other two constructors leave it null.
	protected MatrixCharacteristics mc = null;
	// Context used to build and execute the DML script behind every operation on this matrix.
	protected MLContext ml = null;

	/**
	 * Creates an MLMatrix over a Spark logical plan. The matrix characteristics are left unset.
	 * @param sqlContext SQL context owning the plan
	 * @param logicalPlan logical plan producing the rows
	 * @param ml context used to execute matrix operations
	 */
	protected MLMatrix(SQLContext sqlContext, LogicalPlan logicalPlan, MLContext ml) {
		super(sqlContext, logicalPlan);
		this.ml = ml;
	}

	/**
	 * Creates an MLMatrix over a Spark query execution. The matrix characteristics are left unset.
	 * @param sqlContext SQL context owning the query execution
	 * @param queryExecution query execution producing the rows
	 * @param ml context used to execute matrix operations
	 */
	protected MLMatrix(SQLContext sqlContext, QueryExecution queryExecution, MLContext ml) {
		super(sqlContext, queryExecution);
		this.ml = ml;
	}

	// Only used internally to set a new MLMatrix after one of matrix operations.
	// Not to be used externally.
	protected MLMatrix(DataFrame df, MatrixCharacteristics mc, MLContext ml) throws DMLRuntimeException {
		super(df.sqlContext(), df.logicalPlan());
		this.mc = mc;
		this.ml = ml;
	}

	// DML statement appended to every generated operation script: writes the variable "output"
	// in binary-block format using the default block size.
	// NOTE(review): the "tmp" path is presumably never materialised because callers register
	// "output" via ml.registerOutput before executing -- confirm against MLContext.
	//TODO replace default blocksize
	static final String writeStmt = "write(output, \"tmp\", format=\"binary\", rows_in_block=" + OptimizerUtils.DEFAULT_BLOCKSIZE + ", cols_in_block=" + OptimizerUtils.DEFAULT_BLOCKSIZE + ");";

	// ------------------------------------------------------------------------------------------------
	// /**
	// * Experimental unstable API: Converts our blocked matrix format to MLLib's format
	// * @return
	// */
	// public BlockMatrix toBlockedMatrix() {
	// JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getRDDLazily(this);
	// RDD<Tuple2<Tuple2<Object, Object>, Matrix>> mllibBlocks = blocks.mapToPair(new GetMLLibBlocks(mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock())).rdd();
	// return new BlockMatrix(mllibBlocks, mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getRows(), mc.getCols());
	// }
	// ------------------------------------------------------------------------------------------------

	/**
	 * Wraps a binary-block pair RDD as an MLMatrix whose rows follow the default binary-block schema.
	 * @param ml context the new matrix will use for later operations
	 * @param sqlContext SQL context used to create the backing DataFrame
	 * @param blocks (MatrixIndexes, MatrixBlock) pairs making up the matrix
	 * @param mc characteristics describing {@code blocks}
	 * @return the wrapped matrix
	 * @throws DMLRuntimeException as declared by the MLMatrix constructor
	 */
	static MLMatrix createMLMatrix(MLContext ml, SQLContext sqlContext, JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
		StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
		// map() already yields a JavaRDD<Row>; no need to round-trip through a Scala RDD.
		return new MLMatrix(sqlContext.createDataFrame(blocks.map(new GetMLBlock()), schema), mc, ml);
	}

	// Wraps the "output" variable of an executed script as a new MLMatrix that shares this
	// matrix's SQL context and MLContext.
	private MLMatrix fromOutput(MLOutput out) throws IOException, DMLException {
		StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
		DataFrame df = this.sqlContext().createDataFrame(out.getBinaryBlockedRDD("output").map(new GetMLBlock()), schema);
		return new MLMatrix(df, out.getMatrixCharacteristics("output"), ml);
	}

	// Returns mat.mc, failing with a descriptive error (instead of a NullPointerException) when the
	// matrix was built through a constructor that leaves its characteristics unset.
	private static MatrixCharacteristics requireCharacteristics(MLMatrix mat) throws DMLRuntimeException {
		if(mat.mc == null) {
			throw new DMLRuntimeException("MatrixCharacteristics are not set for this MLMatrix");
		}
		return mat.mc;
	}

	/**
	 * Convenient method to write a MLMatrix.
	 * <p>
	 * Resets the shared MLContext and runs a DML script that writes this matrix. Both arguments
	 * are spliced verbatim into DML string literals, so they must not contain double quotes.
	 * @param filePath destination path passed to DML's write()
	 * @param format output format passed to DML's write() (the class example uses "text")
	 * @throws IOException if script execution fails with an I/O error
	 * @throws DMLException if the script cannot be compiled or executed
	 */
	public void write(String filePath, String format) throws IOException, DMLException {
		ml.reset();
		ml.registerInput("left", this);
		ml.executeScript("left = read(\"\"); output=left; write(output, \"" + filePath + "\", format=\"" + format + "\");");
	}

	// Evaluates the DML builtin fn ("nrow" or "ncol") on this matrix and returns its scalar value.
	private double getScalarBuiltinFunctionResult(String fn) throws IOException, DMLException {
		if(!fn.equals("nrow") && !fn.equals("ncol")) {
			throw new DMLRuntimeException("The function " + fn + " is not yet supported in MLMatrix");
		}
		ml.reset();
		ml.registerInput("left", getRDDLazily(this), mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
		ml.registerOutput("output");
		// The scalar is wrapped in a 1x1 matrix so it can be read back through the binary-block output.
		String script = "left = read(\"\");"
				+ "val = " + fn + "(left); "
				+ "output = matrix(val, rows=1, cols=1); "
				+ writeStmt;
		MLOutput out = ml.executeScript(script);
		List<Tuple2<MatrixIndexes, MatrixBlock>> result = out.getBinaryBlockedRDD("output").collect();
		if(result == null || result.size() != 1) {
			throw new DMLRuntimeException("Error while computing the function: " + fn);
		}
		return result.get(0)._2.getValue(0, 0);
	}

	/**
	 * Gets or computes the number of rows.
	 * Returns the cached row count when known; otherwise runs a DML script evaluating nrow().
	 * @return number of rows
	 * @throws DMLException if the matrix characteristics are unset or the script fails
	 * @throws IOException if script execution fails with an I/O error
	 */
	public long numRows() throws IOException, DMLException {
		if(requireCharacteristics(this).rowsKnown()) {
			return mc.getRows();
		}
		return (long) getScalarBuiltinFunctionResult("nrow");
	}

	/**
	 * Gets or computes the number of columns.
	 * Returns the cached column count when known; otherwise runs a DML script evaluating ncol().
	 * @return number of columns
	 * @throws DMLException if the matrix characteristics are unset or the script fails
	 * @throws IOException if script execution fails with an I/O error
	 */
	public long numCols() throws IOException, DMLException {
		if(requireCharacteristics(this).colsKnown()) {
			return mc.getCols();
		}
		return (long) getScalarBuiltinFunctionResult("ncol");
	}

	/** @return rows per block, read from the matrix characteristics (which must be set) */
	public int rowsPerBlock() {
		return mc.getRowsPerBlock();
	}

	/** @return columns per block, read from the matrix characteristics (which must be set) */
	public int colsPerBlock() {
		return mc.getColsPerBlock();
	}

	// Builds "output = left <op> right" followed by the shared write statement.
	private String getScript(String binaryOperator) {
		return "left = read(\"\");"
				+ "right = read(\"\");"
				+ "output = left " + binaryOperator + " right; "
				+ writeStmt;
	}

	// Builds a matrix-scalar script; the scalar is placed on the left of the operator when isScalarLeft.
	private String getScalarBinaryScript(String binaryOperator, double scalar, boolean isScalarLeft) {
		if(isScalarLeft) {
			return "left = read(\"\");"
					+ "output = " + scalar + " " + binaryOperator + " left ;"
					+ writeStmt;
		}
		else {
			return "left = read(\"\");"
					+ "output = left " + binaryOperator + " " + scalar + ";"
					+ writeStmt;
		}
	}

	// Lazily converts the DataFrame rows of mat back into (MatrixIndexes, MatrixBlock) pairs.
	static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDLazily(MLMatrix mat) {
		return mat.rdd().toJavaRDD().mapToPair(new GetMIMBFromRow());
	}

	// Runs "this <op> that" as a DML script after validating block sizes and known dimensions.
	private MLMatrix matrixBinaryOp(MLMatrix that, String op) throws IOException, DMLException {
		MatrixCharacteristics lhs = requireCharacteristics(this);
		MatrixCharacteristics rhs = requireCharacteristics(that);
		if(lhs.getRowsPerBlock() != rhs.getRowsPerBlock() || lhs.getColsPerBlock() != rhs.getColsPerBlock()) {
			throw new DMLRuntimeException("Incompatible block sizes: brlen:" + lhs.getRowsPerBlock() + "!=" + rhs.getRowsPerBlock() + " || bclen:" + lhs.getColsPerBlock() + "!=" + rhs.getColsPerBlock());
		}
		// Dimensions are compared only when both sides know them; an unknown dimension must not be
		// reported as a mismatch, so any real conflict is left for the generated script to surface.
		if(op.equals("%*%")) {
			if(lhs.colsKnown() && rhs.rowsKnown() && lhs.getCols() != rhs.getRows()) {
				throw new DMLRuntimeException("Dimensions mismatch:" + lhs.getCols() + "!=" + rhs.getRows());
			}
		}
		else {
			boolean rowsMismatch = lhs.rowsKnown() && rhs.rowsKnown() && lhs.getRows() != rhs.getRows();
			boolean colsMismatch = lhs.colsKnown() && rhs.colsKnown() && lhs.getCols() != rhs.getCols();
			if(rowsMismatch || colsMismatch) {
				throw new DMLRuntimeException("Dimensions mismatch:" + lhs.getRows() + "!=" + rhs.getRows() + " || " + lhs.getCols() + "!=" + rhs.getCols());
			}
		}
		ml.reset();
		ml.registerInput("left", this);
		ml.registerInput("right", that);
		ml.registerOutput("output");
		return fromOutput(ml.executeScript(getScript(op)));
	}

	// Runs "this <op> scalar" (or "scalar <op> this" when isScalarLeft) as a DML script.
	private MLMatrix scalarBinaryOp(Double scalar, String op, boolean isScalarLeft) throws IOException, DMLException {
		ml.reset();
		ml.registerInput("left", this);
		ml.registerOutput("output");
		return fromOutput(ml.executeScript(getScalarBinaryScript(op, scalar, isScalarLeft)));
	}

	// ---------------------------------------------------
	// Simple operator overloading; every call runs its own DML script, so nothing is optimized across calls.
	// Method names such as $greater / $percent$times$percent are Scala's encodings of > / %*%.

	// Element-wise comparisons with another matrix.
	public MLMatrix $greater(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, ">");
	}
	public MLMatrix $less(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "<");
	}
	public MLMatrix $greater$eq(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, ">=");
	}
	public MLMatrix $less$eq(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "<=");
	}
	public MLMatrix $eq$eq(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "==");
	}
	public MLMatrix $bang$eq(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "!=");
	}

	// Element-wise power (DML ^).
	public MLMatrix $up(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "^");
	}
	// Despite its name this is NOT exp(); it computes the element-wise power this ^ that, like $up.
	public MLMatrix exp(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "^");
	}

	// Element-wise arithmetic with another matrix; each symbolic method has a named alias.
	public MLMatrix $plus(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "+");
	}
	public MLMatrix add(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "+");
	}
	public MLMatrix $minus(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "-");
	}
	public MLMatrix minus(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "-");
	}
	public MLMatrix $times(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "*");
	}
	public MLMatrix elementWiseMultiply(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "*");
	}
	public MLMatrix $div(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "/");
	}
	public MLMatrix divide(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "/");
	}
	public MLMatrix $percent$div$percent(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "%/%");
	}
	public MLMatrix integerDivision(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "%/%");
	}
	public MLMatrix $percent$percent(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "%%");
	}
	public MLMatrix modulus(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "%%");
	}

	// Matrix multiplication (DML %*%); requires this.cols == that.rows when both are known.
	public MLMatrix $percent$times$percent(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "%*%");
	}
	public MLMatrix multiply(MLMatrix that) throws IOException, DMLException {
		return matrixBinaryOp(that, "%*%");
	}

	/**
	 * Transposes this matrix by running DML's t().
	 * @return the transposed matrix
	 * @throws IOException if script execution fails with an I/O error
	 * @throws DMLException if the script cannot be compiled or executed
	 */
	public MLMatrix transpose() throws IOException, DMLException {
		ml.reset();
		ml.registerInput("left", this);
		ml.registerOutput("output");
		String script = "left = read(\"\");"
				+ "output = t(left); "
				+ writeStmt;
		return fromOutput(ml.executeScript(script));
	}

	// Matrix-scalar operations; the scalar is always placed on the right-hand side.
	// TODO: For 'scalar op matrix' operations: Do implicit conversions
	public MLMatrix $plus(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "+", false);
	}
	public MLMatrix add(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "+", false);
	}
	public MLMatrix $minus(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "-", false);
	}
	public MLMatrix minus(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "-", false);
	}
	public MLMatrix $times(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "*", false);
	}
	public MLMatrix elementWiseMultiply(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "*", false);
	}
	public MLMatrix $div(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "/", false);
	}
	public MLMatrix divide(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "/", false);
	}
	public MLMatrix $greater(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, ">", false);
	}
	public MLMatrix $less(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "<", false);
	}
	public MLMatrix $greater$eq(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, ">=", false);
	}
	public MLMatrix $less$eq(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "<=", false);
	}
	public MLMatrix $eq$eq(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "==", false);
	}
	public MLMatrix $bang$eq(Double scalar) throws IOException, DMLException {
		return scalarBinaryOp(scalar, "!=", false);
	}
}