[SYSTEMDS-2571] Handle naStrings in csv frame read and meta data
- add new test for the naStrings argument and begin changing
DataExpression to handle it.
- List Variable passed to Nary
- add Nary List test for printing lists in DML
- Improve error message in StatementBlock to reflect that frames or
tensors can also be used in a print statement
- only warnings in log4j
- enable NaN finding
- removed debugging output
- fix RandSP Instruction
- add NaN test files to github
- List now can contain any type in any position
- fix time test: handle the missing case of parsing a function call with zero arguments
- Fix DML Translator (bug introduced in last commit) to assign first
expr
- NaN values in frames
- Read NAStrings from mtd file
Closes #990.
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 30cb9e9..7fad1ac 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -167,7 +167,7 @@
* @param conf Hadoop configuration
* @param args arguments
* @return true if success, false otherwise
- * @throws IOException If an internal IO Exception happened.
+ * @throws IOException If an internal IOException happens.
*/
public static boolean executeScript( Configuration conf, String[] args )
throws IOException, ParseException, DMLScriptException
diff --git a/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java
index 6a39fda..838f8c7 100644
--- a/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java
@@ -265,7 +265,7 @@
if (activeMLContext == null) {
System.out.println(MLContextUtil.welcomeMessage());
}
-
+
this.spark = spark;
DMLScript.setGlobalExecMode(executionType.getExecMode());
diff --git a/src/main/java/org/apache/sysds/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysds/api/mlcontext/MLContextConversionUtil.java
index 802059f..3f52929 100644
--- a/src/main/java/org/apache/sysds/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysds/api/mlcontext/MLContextConversionUtil.java
@@ -19,6 +19,13 @@
package org.apache.sysds.api.mlcontext;
+import java.io.InputStream;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
@@ -58,16 +65,10 @@
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
+
import scala.collection.JavaConversions;
import scala.reflect.ClassTag;
-import java.io.InputStream;
-import java.net.URL;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
-
/**
* Utility class containing methods to perform data conversions.
*
@@ -623,13 +624,10 @@
new MetaDataFormat(mc, FileFormat.BINARY),
frameMetadata.getFrameSchema().getSchema().toArray(new ValueType[0]));
JavaPairRDD<Long, FrameBlock> rdd;
- try {
- rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc(), javaPairRDDText, mc, frameObject.getSchema(), false,
- ",", false, -1);
- } catch (DMLRuntimeException e) {
- e.printStackTrace();
- return null;
- }
+
+ rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc(), javaPairRDDText, mc, frameObject.getSchema(), false,
+ ",", false, -1, UtilFunctions.defaultNaString);
+
frameObject.setRDDHandle(new RDDObject(rdd));
return frameObject;
}
diff --git a/src/main/java/org/apache/sysds/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysds/api/mlcontext/MLContextUtil.java
index 4f994e9..eeb479a 100644
--- a/src/main/java/org/apache/sysds/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysds/api/mlcontext/MLContextUtil.java
@@ -967,7 +967,7 @@
}
return sb.toString();
}
-
+
/**
* Obtain the Spark Context
*
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java
index 99cf91e..0046078 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -19,6 +19,11 @@
package org.apache.sysds.hops;
+import java.util.HashMap;
+import java.util.Map.Entry;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.OpOpData;
@@ -36,15 +41,12 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.LocalFileUtils;
-import java.util.HashMap;
-import java.util.Map.Entry;
-
/**
* A DataOp can be either a persistent read/write or transient read/write - writes will always have at least one input,
* but all types can have parameters (e.g., for csv literals of delimiter, header, etc).
*/
-public class DataOp extends Hop
-{
+public class DataOp extends Hop {
+ private static final Log LOG = LogFactory.getLog(DataOp.class.getName());
private OpOpData _op;
private String _fileName = null;
@@ -123,6 +125,9 @@
String s = e.getKey();
Hop input = e.getValue();
getInput().add(input);
+ if (LOG.isDebugEnabled()){
+ LOG.debug(String.format("%15s - %s",s,input));
+ }
input.getParent().add(this);
_paramIndexMap.put(s, index);
diff --git a/src/main/java/org/apache/sysds/hops/DnnOp.java b/src/main/java/org/apache/sysds/hops/DnnOp.java
index e4eed0e..54978f1 100644
--- a/src/main/java/org/apache/sysds/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysds/hops/DnnOp.java
@@ -19,6 +19,10 @@
package org.apache.sysds.hops;
+import java.util.ArrayList;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp2;
@@ -34,10 +38,9 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-import java.util.ArrayList;
+public class DnnOp extends MultiThreadedHop {
+ private static final Log LOG = LogFactory.getLog(DnnOp.class.getName());
-public class DnnOp extends MultiThreadedHop
-{
// -------------------------------------------------------------------------
// This flag allows us to compile plans with less unknowns and also serves as future tensorblock integration.
// By default, these flags are turned on.
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 24aade1..2558c1d 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -58,9 +58,8 @@
import java.util.HashMap;
import java.util.HashSet;
-public abstract class Hop implements ParseInfo
-{
- protected static final Log LOG = LogFactory.getLog(Hop.class.getName());
+public abstract class Hop implements ParseInfo {
+ private static final Log LOG = LogFactory.getLog(Hop.class.getName());
public static final long CPThreshold = 2000;
@@ -725,7 +724,6 @@
if (LOG.isDebugEnabled()){
String s = String.format(" %c %-5s %-8s (%s,%s) %s", c, getHopID(), getOpString(), OptimizerUtils.toMB(_outputMemEstimate), OptimizerUtils.toMB(_memEstimate), et);
- //System.out.println(s);
LOG.debug(s);
}
diff --git a/src/main/java/org/apache/sysds/hops/LiteralOp.java b/src/main/java/org/apache/sysds/hops/LiteralOp.java
index 61a7acb..31930e3 100644
--- a/src/main/java/org/apache/sysds/hops/LiteralOp.java
+++ b/src/main/java/org/apache/sysds/hops/LiteralOp.java
@@ -283,4 +283,9 @@
{
return false;
}
+
+ @Override
+ public String toString(){
+ return getOpString();
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 2f2352d..568d5c7 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -19,6 +19,12 @@
package org.apache.sysds.hops;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map.Entry;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
@@ -39,17 +45,14 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.Map.Entry;
-
/**
* Defines the HOP for calling an internal function (with custom parameters) from a DML script.
*
*/
-public class ParameterizedBuiltinOp extends MultiThreadedHop
-{
+public class ParameterizedBuiltinOp extends MultiThreadedHop {
+ private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinOp.class.getName());
+
public static boolean FORCE_DIST_RM_EMPTY = false;
//operator type
@@ -950,7 +953,7 @@
}
}
catch(Exception ex) {
- LOG.warn(ex.getMessage());
+ LOG.warn("Non Zero Replace Arguments exception: " + ex.getMessage());
}
return ret;
diff --git a/src/main/java/org/apache/sysds/lops/CSVReBlock.java b/src/main/java/org/apache/sysds/lops/CSVReBlock.java
index c188ba3..3005aee 100644
--- a/src/main/java/org/apache/sysds/lops/CSVReBlock.java
+++ b/src/main/java/org/apache/sysds/lops/CSVReBlock.java
@@ -19,11 +19,11 @@
package org.apache.sysds.lops;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
/**
@@ -69,6 +69,8 @@
String.valueOf(DataExpression.DEFAULT_DELIM_FILL));
Lop fillValueLop = dataInput.getNamedInputLop(DataExpression.DELIM_FILL_VALUE,
String.valueOf(DataExpression.DEFAULT_DELIM_FILL_VALUE));
+ Lop naStrings = dataInput.getNamedInputLop(DataExpression.DELIM_NA_STRINGS,
+ String.valueOf(DataExpression.DEFAULT_NA_STRINGS));
if (headerLop.isVariable())
throw new LopsException(this.printErrorLocation()
@@ -94,6 +96,16 @@
sb.append( ((Data)fillLop).getBooleanValue() );
sb.append( OPERAND_DELIMITOR );
sb.append( ((Data)fillValueLop).getDoubleValue() );
+ sb.append( OPERAND_DELIMITOR );
+ if(naStrings instanceof Nary){
+ Nary naLops = (Nary) naStrings;
+ for(Lop na : naLops.getInputs()){
+ sb.append(((Data)na).getStringValue());
+ sb.append(DataExpression.DELIM_NA_STRING_SEP);
+ }
+ } else if (naStrings instanceof Data){
+ sb.append(((Data)naStrings).getStringValue());
+ }
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/lops/Data.java b/src/main/java/org/apache/sysds/lops/Data.java
index 60592f0..b2b6e0c 100644
--- a/src/main/java/org/apache/sysds/lops/Data.java
+++ b/src/main/java/org/apache/sysds/lops/Data.java
@@ -21,20 +21,17 @@
import java.util.HashMap;
-import org.apache.sysds.lops.LopProperties.ExecType;
-import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ValueType;
-
-
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.parser.DataExpression;
/**
* Lop to represent data objects. Data objects represent matrices, vectors,
* variables, literals. Can be for both input and output.
*/
-
public class Data extends Lop
{
public static final String PREAD_PREFIX = "pREAD";
@@ -412,7 +409,6 @@
sb.append( OPERAND_DELIMITOR );
sb.append( (schema!=null) ? schema.prepScalarLabel() : "*" );
}
-
return sb.toString();
}
else {
@@ -433,7 +429,7 @@
Data delimLop = (Data) getNamedInputLop(DataExpression.DELIM_DELIMITER);
Data fillLop = (Data) getNamedInputLop(DataExpression.DELIM_FILL);
Data fillValueLop = (Data) getNamedInputLop(DataExpression.DELIM_FILL_VALUE);
- Data naLop = (Data) getNamedInputLop(DataExpression.DELIM_NA_STRINGS);
+ Lop naLop = getNamedInputLop(DataExpression.DELIM_NA_STRINGS);
sb.append(headerLop.getBooleanValue());
sb.append(OPERAND_DELIMITOR);
@@ -444,7 +440,16 @@
sb.append(fillValueLop.getDoubleValue());
if ( naLop != null ) {
sb.append(OPERAND_DELIMITOR);
- sb.append(naLop.getStringValue());
+ if(naLop instanceof Nary){
+ Nary naLops = (Nary) naLop;
+ for(Lop na : naLops.getInputs()){
+ sb.append(((Data)na).getStringValue());
+ sb.append(DataExpression.DELIM_NA_STRING_SEP);
+ }
+ } else if (naLop instanceof Data){
+
+ sb.append(((Data)naLop).getStringValue());
+ }
}
}
else { // (operation == OperationTypes.WRITE)
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index 76f0caa..d81b6a8 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -21,13 +21,11 @@
import java.util.ArrayList;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
/**
@@ -76,9 +74,6 @@
DONE, NOTVISITED
}
-
- protected static final Log LOG = LogFactory.getLog(Lop.class.getName());
-
public static final String FILE_SEPARATOR = "/";
public static final String PROCESS_PREFIX = "_p";
public static final String CP_ROOT_THREAD_ID = "_t0";
diff --git a/src/main/java/org/apache/sysds/lops/ReBlock.java b/src/main/java/org/apache/sysds/lops/ReBlock.java
index e248c39..399b442 100644
--- a/src/main/java/org/apache/sysds/lops/ReBlock.java
+++ b/src/main/java/org/apache/sysds/lops/ReBlock.java
@@ -19,19 +19,17 @@
package org.apache.sysds.lops;
-
-import org.apache.sysds.lops.LopProperties.ExecType;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
/**
* Lop to perform reblock operation
*/
-public class ReBlock extends Lop
-{
+public class ReBlock extends Lop {
public static final String OPCODE = "rblk";
private boolean _outputEmptyBlocks = true;
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 09d58cc..a9e5a6e 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1574,16 +1574,27 @@
}
else if (source instanceof DataIdentifier)
return hops.get(((DataIdentifier) source).getName());
+ else if (source instanceof ExpressionList){
+ ExpressionList sourceList = (ExpressionList) source;
+ List<Expression> expressions = sourceList.getValue();
+ Hop[] listHops = new Hop[expressions.size()];
+ int idx = 0;
+ for( Expression ex : expressions){
+ listHops[idx++] = processExpression(ex, null, hops);
+ }
+ Hop currBuiltinOp = HopRewriteUtils.createNary(OpOpN.LIST,listHops );
+ return currBuiltinOp;
+ }
+ else{
+ throw new ParseException("Unhandled instance of source type: " + source.getClass());
+ }
}
- catch ( Exception e ) {
- //print exception stacktrace for fatal exceptions w/o messages
- //to allow for error analysis other than ('no parse issue message')
- if( e.getMessage() == null )
- e.printStackTrace();
- throw new ParseException(e.getMessage());
+ catch(ParseException e ){
+ throw e;
}
-
- return null;
+ catch ( Exception e ) {
+ throw new ParseException("An Parsing exception occured", e);
+ }
}
private static DataIdentifier createTarget(Expression source) {
@@ -2220,7 +2231,10 @@
*/
private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target,
HashMap<String, Hop> hops) {
- Hop expr = processExpression(source.getFirstExpr(), null, hops);
+ Hop expr = null;
+ if(source.getFirstExpr() != null){
+ expr = processExpression(source.getFirstExpr(), null, hops);
+ }
Hop expr2 = null;
if (source.getSecondExpr() != null) {
expr2 = processExpression(source.getSecondExpr(), null, hops);
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java
index 4848a8e..52310cc 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -19,11 +19,23 @@
package org.apache.sysds.parser;
+import static org.apache.sysds.runtime.instructions.fed.InitFEDInstruction.FED_FRAME_IDENTIFIER;
+import static org.apache.sysds.runtime.instructions.fed.InitFEDInstruction.FED_MATRIX_IDENTIFIER;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map.Entry;
+import java.util.Set;
+
import org.antlr.v4.runtime.ParserRuleContext;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
-import org.apache.wink.json4j.JSONArray;
-import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
@@ -41,21 +53,13 @@
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.JSONHelper;
-
-import java.io.BufferedReader;
-import java.io.InputStreamReader;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Set;
-import java.util.Map.Entry;
-
-import static org.apache.sysds.runtime.instructions.fed.InitFEDInstruction.FED_FRAME_IDENTIFIER;
-import static org.apache.sysds.runtime.instructions.fed.InitFEDInstruction.FED_MATRIX_IDENTIFIER;
+import org.apache.wink.json4j.JSONArray;
+import org.apache.wink.json4j.JSONObject;
public class DataExpression extends DataIdentifier
{
+ private static final Log LOG = LogFactory.getLog(DataExpression.class.getName());
+
public static final String RAND_DIMS = "dims";
public static final String RAND_ROWS = "rows";
@@ -783,10 +787,9 @@
if (inputParamExpr instanceof FunctionCallIdentifier) {
raiseValidateError("UDF function call not supported as parameter to built-in function call", false,LanguageErrorCodes.INVALID_PARAMETERS);
}
-
inputParamExpr.validateExpression(ids, currConstVars, conditional);
- if ( getVarParam(s).getOutput().getDataType() != DataType.SCALAR && !s.equals(RAND_DATA) &&
- !s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES)) {
+ if (s != null && !s.equals(RAND_DATA) && !s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES)
+ && !s.equals(DELIM_NA_STRINGS) && getVarParam(s).getOutput().getDataType() != DataType.SCALAR ) {
raiseValidateError("Non-scalar data types are not supported for data expression.", conditional,LanguageErrorCodes.INVALID_PARAMETERS);
}
}
@@ -963,6 +966,8 @@
}
}
+ boolean isCSV = (formatTypeString != null && formatTypeString.equalsIgnoreCase(FileFormat.CSV.toString()));
+
if (shouldReadMTD){
configObject = readMetadataFile(mtdFileName, conditional);
// if the MTD file exists, check the values specified in read statement match values in metadata MTD file
@@ -971,32 +976,23 @@
inferredFormatType = true;
}
else {
- LOG.warn("Metadata file: " + new Path(mtdFileName) + " not provided");
+ if(!isCSV){
+ LOG.warn("Metadata file: " + new Path(mtdFileName) + " not provided");
+ }
}
}
- boolean isCSV = false;
- isCSV = (formatTypeString != null && formatTypeString.equalsIgnoreCase(FileFormat.CSV.toString()));
if (isCSV){
- // Handle delimited file format
- //
- // 1) only allow IO_FILENAME, _HEADER_ROW, FORMAT_DELIMITER, READROWPARAM, READCOLPARAM
- //
- // 2) open the file
- //
-
+
// there should be no MTD file for delimited file format
shouldReadMTD = true;
- // only allow IO_FILENAME, HAS_HEADER_ROW, FORMAT_DELIMITER, READROWPARAM, READCOLPARAM
- // as ONLY valid parameters
+ // Handle valid ParamNames.
if( !inferredFormatType ){
for (String key : _varParams.keySet()){
- if (! READ_VALID_PARAM_NAMES.contains(key)) {
- String msg = "Only parameters allowed are: " + Arrays.toString(new String[] {
- IO_FILENAME, FORMAT_TYPE, SCHEMAPARAM, DELIM_HAS_HEADER_ROW, DELIM_DELIMITER,
- DELIM_FILL, DELIM_FILL_VALUE, READNNZPARAM, READROWPARAM, DATATYPEPARAM,
- VALUETYPEPARAM, READCOLPARAM,DELIM_NA_STRINGS});
+ if (! READ_VALID_PARAM_NAMES.contains(key))
+ {
+ String msg = "Only parameters allowed are: " + READ_VALID_PARAM_NAMES;
raiseValidateError("Invalid parameter " + key + " in read statement: " +
toString() + ". " + msg, conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
@@ -1059,10 +1055,11 @@
}
else {
if ((getVarParam(DELIM_NA_STRINGS) instanceof ConstIdentifier)
- && (! (getVarParam(DELIM_NA_STRINGS) instanceof StringIdentifier)))
- {
- raiseValidateError("For delimited file '" + getVarParam(DELIM_NA_STRINGS) + "' must be a string value ", conditional);
- }
+ && (! (getVarParam(DELIM_NA_STRINGS) instanceof StringIdentifier)))
+ {
+ raiseValidateError("For delimited file '" + getVarParam(DELIM_NA_STRINGS) + "' must be a string value ", conditional);
+ }
+ LOG.info("Replacing :" + _varParams.get(DELIM_NA_STRINGS) + " with NaN");
}
}
diff --git a/src/main/java/org/apache/sysds/parser/Expression.java b/src/main/java/org/apache/sysds/parser/Expression.java
index 25eeade..059d093 100644
--- a/src/main/java/org/apache/sysds/parser/Expression.java
+++ b/src/main/java/org/apache/sysds/parser/Expression.java
@@ -70,7 +70,7 @@
INTERNAL, EXTERNAL
}
- protected static final Log LOG = LogFactory.getLog(Expression.class.getName());
+ private static final Log LOG = LogFactory.getLog(Expression.class.getName());
private static final IDSequence _tempId = new IDSequence();
protected Identifier[] _outputs;
diff --git a/src/main/java/org/apache/sysds/parser/ExpressionList.java b/src/main/java/org/apache/sysds/parser/ExpressionList.java
index ed402ce..90d50c6 100644
--- a/src/main/java/org/apache/sysds/parser/ExpressionList.java
+++ b/src/main/java/org/apache/sysds/parser/ExpressionList.java
@@ -20,50 +20,74 @@
package org.apache.sysds.parser;
import java.util.ArrayList;
+import java.util.HashMap;
public class ExpressionList extends Expression {
protected String _name;
protected ArrayList<Expression> _value;
-
+
public ExpressionList(ArrayList<Expression> value) {
this._name = "tmp";
this._value = value;
}
-
+
public String getName() {
return _name;
}
+
public void setName(String _name) {
this._name = _name;
}
+
public ArrayList<Expression> getValue() {
return _value;
}
+
public void setValue(ArrayList<Expression> _value) {
this._value = _value;
}
-
+
+ @Override
+ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> currConstVars,
+ boolean conditional) {
+ for(Expression ex : _value) {
+ ex.validateExpression(ids, currConstVars, conditional);
+ }
+ }
+
@Override
public Expression rewriteExpression(String prefix) {
throw new LanguageException("ExpressionList should not be exposed beyond parser layer.");
}
+
@Override
public VariableSet variablesRead() {
VariableSet result = new VariableSet();
- for( Expression expr : _value ) {
- result.addVariables ( expr.variablesRead() );
+ for(Expression expr : _value) {
+ result.addVariables(expr.variablesRead());
}
return result;
}
+
@Override
public VariableSet variablesUpdated() {
VariableSet result = new VariableSet();
- for( Expression expr : _value ) {
- result.addVariables ( expr.variablesUpdated() );
+ for(Expression expr : _value) {
+ result.addVariables(expr.variablesUpdated());
}
return result;
}
-
-}
\ No newline at end of file
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(super.toString());
+ sb.append("[");
+ for(Expression e : _value) {
+ sb.append(e);
+ }
+ sb.append("]");
+ return sb.toString();
+ }
+}
diff --git a/src/main/java/org/apache/sysds/parser/IndexedIdentifier.java b/src/main/java/org/apache/sysds/parser/IndexedIdentifier.java
index e6e1c7f..bc970cc 100644
--- a/src/main/java/org/apache/sysds/parser/IndexedIdentifier.java
+++ b/src/main/java/org/apache/sysds/parser/IndexedIdentifier.java
@@ -22,10 +22,14 @@
import java.util.ArrayList;
import java.util.HashMap;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.util.UtilFunctions;
public class IndexedIdentifier extends DataIdentifier
{
+
+ private static final Log LOG = LogFactory.getLog(DataExpression.class.getName());
// stores the expressions containing the ranges for the
private Expression _rowLowerBound = null, _rowUpperBound = null, _colLowerBound = null, _colUpperBound = null;
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index e1b0b98..5d19575 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -862,12 +862,17 @@
List<Expression> expressions = pstmt.getExpressions();
for (Expression expression : expressions) {
expression.validateExpression(ids.getVariables(), currConstVars, conditional);
- if (expression.getOutput().getDataType() != DataType.SCALAR) {
- if (expression.getOutput().getDataType() == DataType.MATRIX) {
- pstmt.raiseValidateError("Print statements can only print scalars. To print a matrix, please wrap it in a toString() function.", conditional);
- } else {
- pstmt.raiseValidateError("Print statements can only print scalars.", conditional);
- }
+ DataType outputDatatype = expression.getOutput().getDataType();
+ switch (outputDatatype) {
+ case SCALAR:
+ break;
+ case MATRIX:
+ case TENSOR:
+ case FRAME:
+ case LIST:
+ pstmt.raiseValidateError("Print statements can only print scalars. To print a " + outputDatatype + ", please wrap it in a toString() function.", conditional);
+ default:
+ pstmt.raiseValidateError("Print statements can only print scalars. Input datatype was: " + outputDatatype, conditional);
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
index e5afdd1..7ff83f5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
@@ -23,6 +23,8 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.Set;
+import java.util.StringTokenizer;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
@@ -30,9 +32,6 @@
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.ProgramConverter;
-import java.util.Set;
-import java.util.StringTokenizer;
-
/**
* Replaces <code>HashMap⟨String, Data⟩</code> as the table of
* variable names and references. No longer supports global consistency.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index bcc9755..3354eb8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -58,7 +58,6 @@
import java.util.List;
import java.util.stream.Collectors;
-
public class ExecutionContext {
protected static final Log LOG = LogFactory.getLog(ExecutionContext.class.getName());
@@ -738,4 +737,17 @@
private static String getNonExistingVarError(String varname) {
return "Variable '" + varname + "' does not exist in the symbol table.";
}
+
+ @Override
+ public String toString(){
+ StringBuilder sb = new StringBuilder();
+ sb.append(super.toString());
+ if(_prog != null)
+ sb.append("\nProgram: " + _prog.toString());
+ if(_variables != null)
+ sb.append("\nLocalVariableMap: " + _variables.toString());
+ if(_lineage != null)
+ sb.append("\nLineage: " + _lineage.toString());
+ return sb.toString();
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index b0e062d..bf0b514 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -71,23 +71,18 @@
*/
public class OptimizationWrapper
{
-
private static final boolean LDEBUG = false; //internal local debug level
private static final Log LOG = LogFactory.getLog(OptimizationWrapper.class.getName());
//internal parameters
public static final double PAR_FACTOR_INFRASTRUCTURE = 1.0;
- private static final boolean CHECK_PLAN_CORRECTNESS = false;
-
- static
- {
- // for internal debugging only
- if( LDEBUG ) {
- Logger.getLogger("org.apache.sysds.runtime.controlprogram.parfor.opt")
- .setLevel(Level.DEBUG);
- }
- }
+ private static final boolean CHECK_PLAN_CORRECTNESS = false;
+ static {
+ if( LDEBUG )
+ setLogLevel(Level.DEBUG);
+ }
+
/**
* Called once per top-level parfor (during runtime, on parfor execute)
* in order to optimize the specific parfor program block.
@@ -122,12 +117,9 @@
StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_T, timeVal);
}
- public static void setLogLevel( Level optLogLevel )
- {
- if( !LDEBUG ){ //set log level if not overwritten by internal flag
- Logger.getLogger("org.apache.sysds.runtime.controlprogram.parfor.opt")
+ public static void setLogLevel( Level optLogLevel ) {
+ Logger.getLogger("org.apache.sysds.runtime.controlprogram.parfor.opt")
.setLevel( optLogLevel );
- }
}
@SuppressWarnings("unused")
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java
index 948bf97..d04f324 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/Optimizer.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.controlprogram.parfor.opt;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.POptMode;
@@ -38,8 +36,6 @@
*/
public abstract class Optimizer
{
- protected static final Log LOG = LogFactory.getLog(Optimizer.class.getName());
-
protected long _numTotalPlans = -1;
protected long _numEvaluatedPlans = -1;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
index e965ecc..b0f9ee5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
@@ -22,6 +22,8 @@
import java.util.HashMap;
import java.util.HashSet;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LopProperties;
@@ -57,8 +59,8 @@
* - 11) rewrite set result merge
*
*/
-public class OptimizerConstrained extends OptimizerRuleBased
-{
+public class OptimizerConstrained extends OptimizerRuleBased {
+ private static final Log LOG = LogFactory.getLog(OptimizerConstrained.class.getName());
@Override
public POptMode getOptMode() {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerHeuristic.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerHeuristic.java
index 4c08bf9..ba79de0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerHeuristic.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerHeuristic.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.controlprogram.parfor.opt;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.POptMode;
import org.apache.sysds.runtime.controlprogram.parfor.opt.CostEstimator.TestMeasure;
@@ -30,8 +32,8 @@
*
*
*/
-public class OptimizerHeuristic extends OptimizerRuleBased
-{
+public class OptimizerHeuristic extends OptimizerRuleBased {
+ private static final Log LOG = LogFactory.getLog(OptimizerHeuristic.class.getName());
public static final double EXEC_TIME_THRESHOLD = 30000; //in ms
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 63ae8af..15cef32 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.controlprogram.parfor.opt;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.sysds.api.DMLScript;
@@ -137,8 +139,9 @@
* TODO blockwise partitioning
*
*/
-public class OptimizerRuleBased extends Optimizer
-{
+public class OptimizerRuleBased extends Optimizer {
+ private static final Log LOG = LogFactory.getLog(OptimizerRuleBased.class.getName());
+
public static final double PROB_SIZE_THRESHOLD_REMOTE = 100; //wrt # top-level iterations (min)
public static final double PROB_SIZE_THRESHOLD_PARTITIONING = 2; //wrt # top-level iterations (min)
public static final double PROB_SIZE_THRESHOLD_MB = 256*1024*1024; //wrt overall memory consumption (min)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index da0ab63..85f3717 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -37,7 +37,7 @@
FEDERATED
}
- protected static final Log LOG = LogFactory.getLog(Instruction.class.getName());
+ private static final Log LOG = LogFactory.getLog(Instruction.class.getName());
public static final String OPERAND_DELIM = Lop.OPERAND_DELIMITOR;
public static final String DATATYPE_PREFIX = Lop.DATATYPE_PREFIX;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index 100251d..821a9ff 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -41,8 +41,9 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.utils.Explain;
-public class AggregateUnaryCPInstruction extends UnaryCPInstruction
-{
+public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
+ // private static final Log LOG = LogFactory.getLog(AggregateUnaryCPInstruction.class.getName());
+
public enum AUType {
NROW, NCOL, LENGTH, EXISTS, LINEAGE,
COUNT_DISTINCT, COUNT_DISTINCT_APPROX,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index 6f92b8f..dc7a5f1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -20,6 +20,8 @@
package org.apache.sysds.runtime.instructions.cp;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.ValueType;
@@ -41,7 +43,7 @@
public class DataGenCPInstruction extends UnaryCPInstruction {
-
+ private static final Log LOG = LogFactory.getLog(DataGenCPInstruction.class.getName());
private OpOpDG method;
private final CPOperand rows, cols, dims;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
index bc063c2..ba17e3c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
@@ -20,6 +20,9 @@
package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -27,13 +30,14 @@
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysds.runtime.matrix.data.LibMatrixDNN.PoolingType;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.LibMatrixDNN.PoolingType;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.utils.NativeHelper;
public class DnnCPInstruction extends UnaryCPInstruction {
+ private static final Log LOG = LogFactory.getLog(DnnCPInstruction.class.getName());
private static boolean warnedUnderUtilitization = false;
private final CPOperand _in2;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 6be4761..ad24ced 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -25,11 +25,13 @@
import java.util.List;
import java.util.stream.Collectors;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
-import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
@@ -49,6 +51,7 @@
import org.apache.sysds.utils.Statistics;
public class FunctionCallCPInstruction extends CPInstruction {
+ private static final Log LOG = LogFactory.getLog(FunctionCallCPInstruction.class.getName());
private final String _functionName;
private final String _namespace;
private final boolean _opt;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java
index ea178bb..1b2e304 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java
@@ -24,11 +24,10 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
-import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
-public final class MatrixAppendCPInstruction extends AppendCPInstruction implements LineageTraceable {
+public final class MatrixAppendCPInstruction extends AppendCPInstruction {
protected MatrixAppendCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
AppendType type, String opcode, String istr) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f284240..f330793 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -28,6 +28,8 @@
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.Lop;
@@ -35,7 +37,6 @@
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
@@ -59,6 +60,7 @@
import org.apache.sysds.runtime.util.DataConverter;
public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction {
+ private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
private static final int TOSTRING_MAXROWS = 100;
private static final int TOSTRING_MAXCOLS = 100;
private static final int TOSTRING_DECIMAL = 3;
@@ -327,29 +329,35 @@
//get input matrix/frame and convert to string
String out = null;
-
- CacheableData<?> cacheData = ec.getCacheableData(getParam("target"));
+
+ Data cacheData = ec.getVariable(getParam("target"));
if( cacheData instanceof MatrixObject) {
- MatrixBlock matrix = (MatrixBlock) cacheData.acquireRead();
+ MatrixBlock matrix = ((MatrixObject)cacheData).acquireRead();
warnOnTrunction(matrix, rows, cols);
out = DataConverter.toString(matrix, sparse, separator, lineSeparator, rows, cols, decimal);
}
else if( cacheData instanceof TensorObject ) {
- TensorBlock tensor = (TensorBlock) cacheData.acquireRead();
+ TensorBlock tensor = ((TensorObject)cacheData).acquireRead();
// TODO improve truncation to check all dimensions
warnOnTrunction(tensor, rows, cols);
out = DataConverter.toString(tensor, sparse, separator,
lineSeparator, "[", "]", rows, cols, decimal);
}
else if( cacheData instanceof FrameObject ) {
- FrameBlock frame = (FrameBlock) cacheData.acquireRead();
+ FrameBlock frame = ((FrameObject) cacheData).acquireRead();
warnOnTrunction(frame, rows, cols);
out = DataConverter.toString(frame, sparse, separator, lineSeparator, rows, cols, decimal);
}
- else {
- throw new DMLRuntimeException("toString only converts matrix, tensors or frames to string");
+ else if (cacheData instanceof ListObject){
+ out = DataConverter.toString((ListObject) cacheData, rows, cols,
+ sparse, separator, lineSeparator, rows, cols, decimal);
}
- ec.releaseCacheableData(getParam("target"));
+ else {
+ throw new DMLRuntimeException("toString only converts matrix, tensors, lists or frames to string");
+ }
+ if(!(cacheData instanceof ListObject)){
+ ec.releaseCacheableData(getParam("target"));
+ }
ec.setScalarOutput(output.getName(), new StringObject(out));
}
else if( opcode.equals("nvlist") ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 870325c..c42ec91 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,10 +19,6 @@
package org.apache.sysds.runtime.instructions.cp;
-import static org.apache.sysds.parser.Statement.PSFrequency;
-import static org.apache.sysds.parser.Statement.PSModeType;
-import static org.apache.sysds.parser.Statement.PSScheme;
-import static org.apache.sysds.parser.Statement.PSUpdateType;
import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN;
import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE;
import static org.apache.sysds.parser.Statement.PS_EPOCHS;
@@ -49,13 +45,17 @@
import java.util.stream.IntStream;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
-import org.apache.log4j.Level;
-import org.apache.log4j.Logger;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.parser.Statement.PSFrequency;
+import org.apache.sysds.parser.Statement.PSModeType;
+import org.apache.sysds.parser.Statement.PSScheme;
+import org.apache.sysds.parser.Statement.PSUpdateType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -77,24 +77,14 @@
import org.apache.sysds.utils.Statistics;
public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction {
-
+ private static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
+
private static final int DEFAULT_BATCH_SIZE = 64;
private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.EPOCH;
private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS;
private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL;
private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
- //internal local debug level
- private static final boolean LDEBUG = false;
-
- static {
- // for internal debugging only
- if (LDEBUG) {
- Logger.getLogger("org.apache.sysds.runtime.controlprogram.paramserv").setLevel(Level.DEBUG);
- Logger.getLogger(ParamservBuiltinCPInstruction.class.getName()).setLevel(Level.DEBUG);
- }
- }
-
public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
super(op, paramsMap, out, opcode, istr);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 6cc83a0..2082047 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -19,6 +19,10 @@
package org.apache.sysds.runtime.instructions.cp;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
@@ -65,12 +69,7 @@
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
public class VariableCPInstruction extends CPInstruction implements LineageTraceable {
-
/*
* Supported Operations
* --------------------
@@ -420,7 +419,7 @@
boolean hasHeader = Boolean.parseBoolean(parts[curPos]);
String delim = parts[curPos+1];
boolean fill = Boolean.parseBoolean(parts[curPos+2]);
- double fillValue = UtilFunctions.parseToDouble(parts[curPos+3]);
+ double fillValue = UtilFunctions.parseToDouble(parts[curPos+3],UtilFunctions.defaultNaString);
String naStrings = null;
if ( parts.length == 16+extSchema )
naStrings = parts[curPos+4];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/GPUInstruction.java
index 4426bc0..2a63b8e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/GPUInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/GPUInstruction.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.instructions.gpu;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -30,6 +32,8 @@
import org.apache.sysds.utils.Statistics;
public abstract class GPUInstruction extends Instruction {
+ private static final Log LOG = LogFactory.getLog(GPUInstruction.class.getName());
+
public enum GPUINSTRUCTION_TYPE {
AggregateUnary,
AggregateBinary,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
index 1091ef5..9a3ae12 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.instructions.gpu;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -30,7 +32,8 @@
import org.apache.sysds.utils.GPUStatistics;
public class MatrixBuiltinGPUInstruction extends BuiltinUnaryGPUInstruction {
-
+ private static final Log LOG = LogFactory.getLog(MatrixBuiltinGPUInstruction.class.getName());
+
protected MatrixBuiltinGPUInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String instr) {
super(op, in, out, 1, opcode, instr);
_gputype = GPUINSTRUCTION_TYPE.BuiltinUnary;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java
index b8d12d2..d073a3c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CSVReblockSPInstruction.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.HashSet;
+import java.util.Set;
+
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.spark.api.java.JavaPairRDD;
@@ -26,6 +29,7 @@
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -49,9 +53,10 @@
private String _delim;
private boolean _fill;
private double _fillValue;
+ private Set<String> _naStrings;
protected CSVReblockSPInstruction(Operator op, CPOperand in, CPOperand out, int br, int bc, boolean hasHeader,
- String delim, boolean fill, double fillValue, String opcode, String instr) {
+ String delim, boolean fill, double fillValue, String opcode, String instr, Set<String> naStrings) {
super(SPType.CSVReblock, op, in, out, opcode, instr);
_blen = br;
_blen = bc;
@@ -59,6 +64,7 @@
_delim = delim;
_fill = fill;
_fillValue = fillValue;
+ _naStrings = naStrings;
}
public static CSVReblockSPInstruction parseInstruction(String str) {
@@ -79,8 +85,14 @@
boolean fill = Boolean.parseBoolean(parts[6]);
double fillValue = Double.parseDouble(parts[7]);
+ // Set<String> naStrings = UtilFunctions.defaultNaString;
+ Set<String> naStrings = new HashSet<>();
+ for(String s:parts[8].split(DataExpression.DELIM_NA_STRING_SEP)){
+ naStrings.add(s);
+ }
+
return new CSVReblockSPInstruction(null, in, out, blen, blen,
- hasHeader, delim, fill, fillValue, opcode, str);
+ hasHeader, delim, fill, fillValue, opcode, str, naStrings);
}
@Override
@@ -131,7 +143,7 @@
//reblock csv to binary block
return RDDConverterUtils.csvToBinaryBlock(sec.getSparkContext(),
- in, mcOut, _hasHeader, _delim, _fill, _fillValue);
+ in, mcOut, _hasHeader, _delim, _fill, _fillValue, _naStrings);
}
@SuppressWarnings("unchecked")
@@ -143,6 +155,6 @@
//reblock csv to binary block
return FrameRDDConverterUtils.csvToBinaryBlock(sec.getSparkContext(),
- in, mcOut, schema, _hasHeader, _delim, _fill, _fillValue);
+ in, mcOut, schema, _hasHeader, _delim, _fill, _fillValue, _naStrings);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index 4dce9c1..c8ff6b6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -20,6 +20,11 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.Iterator;
+import java.util.stream.IntStream;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
@@ -48,12 +53,12 @@
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+
import scala.Tuple2;
-import java.util.Iterator;
-import java.util.stream.IntStream;
-
public class MapmmSPInstruction extends BinarySPInstruction {
+ private static final Log LOG = LogFactory.getLog(MapmmSPInstruction.class.getName());
+
private CacheType _type = null;
private boolean _outputEmpty = true;
private SparkAggType _aggtype;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
index 17315f0..dac0aa5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
@@ -19,7 +19,17 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Random;
+
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.random.Well1024a;
import org.apache.hadoop.fs.FileSystem;
@@ -35,11 +45,11 @@
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.Lop;
-import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -64,18 +74,12 @@
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
+
import scala.Array;
import scala.Tuple2;
-import java.io.IOException;
-import java.io.PrintWriter;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.Random;
-
public class RandSPInstruction extends UnarySPInstruction {
+ private static final Log LOG = LogFactory.getLog(RandSPInstruction.class.getName());
// internal configuration
private static final long INMEMORY_NUMBLOCKS_THRESHOLD = 1024 * 1024;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java
index cf0d162..46ab52e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.Set;
+
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.spark.api.java.JavaPairRDD;
@@ -141,6 +143,7 @@
String delim = ",";
boolean fill = false;
double fillValue = 0;
+ Set<String> naStrings = null;
if(mo.getFileFormatProperties() instanceof FileFormatPropertiesCSV
&& mo.getFileFormatProperties() != null )
{
@@ -149,9 +152,10 @@
delim = props.getDelim();
fill = props.isFill();
fillValue = props.getFillValue();
+ naStrings = props.getNAStrings();
}
- csvInstruction = new CSVReblockSPInstruction(null, input1, output, mcOut.getBlocksize(), mcOut.getBlocksize(), hasHeader, delim, fill, fillValue, "csvrblk", instString);
+ csvInstruction = new CSVReblockSPInstruction(null, input1, output, mcOut.getBlocksize(), mcOut.getBlocksize(), hasHeader, delim, fill, fillValue, "csvrblk", instString, naStrings);
csvInstruction.processInstruction(sec);
return;
}
@@ -214,6 +218,7 @@
String delim = ",";
boolean fill = false;
double fillValue = 0;
+ Set<String> naStrings = null;
if(fo.getFileFormatProperties() instanceof FileFormatPropertiesCSV
&& fo.getFileFormatProperties() != null )
{
@@ -222,9 +227,10 @@
delim = props.getDelim();
fill = props.isFill();
fillValue = props.getFillValue();
+ naStrings = props.getNAStrings();
}
- csvInstruction = new CSVReblockSPInstruction(null, input1, output, mcOut.getBlocksize(), mcOut.getBlocksize(), hasHeader, delim, fill, fillValue, "csvrblk", instString);
+ csvInstruction = new CSVReblockSPInstruction(null, input1, output, mcOut.getBlocksize(), mcOut.getBlocksize(), hasHeader, delim, fill, fillValue, "csvrblk", instString, naStrings);
csvInstruction.processInstruction(sec);
}
else {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index a74162f..5b6f2e4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -19,6 +19,11 @@
package org.apache.sysds.runtime.instructions.spark;
+import java.util.ArrayList;
+import java.util.Iterator;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
@@ -52,12 +57,12 @@
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
+
import scala.Tuple2;
-import java.util.ArrayList;
-import java.util.Iterator;
-
public class ReorgSPInstruction extends UnarySPInstruction {
+ private static final Log LOG = LogFactory.getLog(ReorgSPInstruction.class.getName());
+
// sort-specific attributes (to enable variable attributes)
private CPOperand _col = null;
private CPOperand _desc = null;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java
index e87e302..1d9e484 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java
@@ -19,6 +19,13 @@
package org.apache.sysds.runtime.instructions.spark.utils;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.LongWritable;
@@ -61,13 +68,8 @@
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.FastStringTokenizer;
import org.apache.sysds.runtime.util.UtilFunctions;
-import scala.Tuple2;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
+import scala.Tuple2;
@@ -80,7 +82,7 @@
public static JavaPairRDD<Long, FrameBlock> csvToBinaryBlock(JavaSparkContext sc,
JavaPairRDD<LongWritable, Text> input, DataCharacteristics mc, ValueType[] schema,
- boolean hasHeader, String delim, boolean fill, double fillValue)
+ boolean hasHeader, String delim, boolean fill, double fillValue, Set<String> naStrings)
{
//determine unknown dimensions and sparsity if required
if( !mc.dimsKnown() ) { //nnz irrelevant here
@@ -105,21 +107,21 @@
//convert csv rdd to binary block rdd (w/ partial blocks)
JavaPairRDD<Long, FrameBlock> out = prepinput.mapPartitionsToPair(
- new CSVToBinaryBlockFunction(mc, schema, hasHeader, delim));
+ new CSVToBinaryBlockFunction(mc, schema, hasHeader, delim, naStrings));
return out;
}
public static JavaPairRDD<Long, FrameBlock> csvToBinaryBlock(JavaSparkContext sc,
JavaRDD<String> input, DataCharacteristics mcOut, ValueType[] schema,
- boolean hasHeader, String delim, boolean fill, double fillValue)
+ boolean hasHeader, String delim, boolean fill, double fillValue, Set<String> naStrings)
{
//convert string rdd to serializable longwritable/text
JavaPairRDD<LongWritable, Text> prepinput =
input.mapToPair(new StringToSerTextFunction());
//convert to binary block
- return csvToBinaryBlock(sc, prepinput, mcOut, schema, hasHeader, delim, fill, fillValue);
+ return csvToBinaryBlock(sc, prepinput, mcOut, schema, hasHeader, delim, fill, fillValue, naStrings);
}
public static JavaRDD<String> binaryBlockToCsv(JavaPairRDD<Long,FrameBlock> in,
@@ -549,13 +551,15 @@
private String[] _colnames = null;
private List<String> _mvMeta = null; //missing value meta data
private List<String> _ndMeta = null; //num distinct meta data
+ private Set<String> _naStrings;
- public CSVToBinaryBlockFunction(DataCharacteristics mc, ValueType[] schema, boolean hasHeader, String delim) {
+ public CSVToBinaryBlockFunction(DataCharacteristics mc, ValueType[] schema, boolean hasHeader, String delim, Set<String> naStrings) {
_clen = mc.getCols();
_schema = schema;
_hasHeader = hasHeader;
_delim = delim;
_maxRowsPerBlock = Math.max((int) (FrameBlock.BUFFER_SIZE/_clen), 1);
+ _naStrings = naStrings;
}
@Override
@@ -597,7 +601,7 @@
}
//split and process row data
- fb.appendRow(IOUtilFunctions.splitCSV(row, _delim, tmprow));
+ fb.appendRow(IOUtilFunctions.splitCSV(row, _delim, tmprow, _naStrings));
}
//flush last blocks
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java
index 94b668c..e789398 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDConverterUtils.java
@@ -19,6 +19,13 @@
package org.apache.sysds.runtime.instructions.spark.utils;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
@@ -66,16 +73,12 @@
import org.apache.sysds.runtime.util.FastStringTokenizer;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.UtilFunctions;
+
import scala.Tuple2;
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
+public class RDDConverterUtils {
+ // private static final Log LOG = LogFactory.getLog(RDDConverterUtils.class.getName());
-public class RDDConverterUtils
-{
public static final String DF_ID_COLUMN = "__INDEX";
public static JavaPairRDD<MatrixIndexes, MatrixBlock> textCellToBinaryBlock(JavaSparkContext sc,
@@ -164,7 +167,8 @@
public static JavaPairRDD<MatrixIndexes, MatrixBlock> csvToBinaryBlock(JavaSparkContext sc,
JavaPairRDD<LongWritable, Text> input, DataCharacteristics mc,
- boolean hasHeader, String delim, boolean fill, double fillValue) {
+ boolean hasHeader, String delim, boolean fill, double fillValue, Set<String> naStrings) {
+
//determine unknown dimensions and sparsity if required
//(w/ robustness for mistakenly counted header in nnz)
if( !mc.dimsKnown(true) ) {
@@ -185,7 +189,7 @@
boolean sparse = requiresSparseAllocation(prepinput, mc);
JavaPairRDD<MatrixIndexes, MatrixBlock> out =
prepinput.mapPartitionsToPair(new CSVToBinaryBlockFunction(
- mc, sparse, hasHeader, delim, fill, fillValue));
+ mc, sparse, hasHeader, delim, fill, fillValue, naStrings));
//aggregate partial matrix blocks (w/ preferred number of output
//partitions as the data is likely smaller in binary block format,
@@ -196,14 +200,14 @@
public static JavaPairRDD<MatrixIndexes, MatrixBlock> csvToBinaryBlock(JavaSparkContext sc,
JavaRDD<String> input, DataCharacteristics mcOut,
- boolean hasHeader, String delim, boolean fill, double fillValue)
+ boolean hasHeader, String delim, boolean fill, double fillValue, Set<String> naStrings)
{
//convert string rdd to serializable longwritable/text
JavaPairRDD<LongWritable, Text> prepinput =
input.mapToPair(new StringToSerTextFunction());
//convert to binary block
- return csvToBinaryBlock(sc, prepinput, mcOut, hasHeader, delim, fill, fillValue);
+ return csvToBinaryBlock(sc, prepinput, mcOut, hasHeader, delim, fill, fillValue, naStrings);
}
public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(JavaSparkContext sc,
@@ -656,8 +660,9 @@
private String _delim = null;
private boolean _fill = false;
private double _fillValue = 0;
+ private Set<String> _naStrings;
- public CSVToBinaryBlockFunction(DataCharacteristics mc, boolean sparse, boolean hasHeader, String delim, boolean fill, double fillValue)
+ public CSVToBinaryBlockFunction(DataCharacteristics mc, boolean sparse, boolean hasHeader, String delim, boolean fill, double fillValue, Set<String> naStrings)
{
_rlen = mc.getRows();
_clen = mc.getCols();
@@ -668,20 +673,19 @@
_delim = delim;
_fill = fill;
_fillValue = fillValue;
+ _naStrings = naStrings == null ? UtilFunctions.defaultNaString : naStrings;
}
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<Text,Long>> arg0)
- throws Exception
- {
+ throws Exception {
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<>();
int ncblks = (int)Math.ceil((double)_clen/_blen);
MatrixIndexes[] ix = new MatrixIndexes[ncblks];
MatrixBlock[] mb = new MatrixBlock[ncblks];
-
- while( arg0.hasNext() )
- {
+
+ while(arg0.hasNext()){
Tuple2<Text,Long> tmp = arg0.next();
String row = tmp._1().toString();
long rowix = tmp._2() + (_header ? 0 : 1);
@@ -713,10 +717,10 @@
mb[cix-1].getSparseBlock().allocate(pos, lnnz);
}
for( int j=0; j<lclen; j++ ) {
- String part = parts[pix++];
+ String part = parts[pix++].trim();
emptyFound |= part.isEmpty() && !_fill;
double val = (part.isEmpty() && _fill) ?
- _fillValue : UtilFunctions.parseToDouble(part);
+ _fillValue : UtilFunctions.parseToDouble(part, _naStrings);
mb[cix-1].appendValue(pos, j, val);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java b/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java
index 6c6de00..7049918 100644
--- a/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java
+++ b/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java
@@ -20,10 +20,12 @@
package org.apache.sysds.runtime.io;
import java.io.Serializable;
+import java.util.HashSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.util.UtilFunctions;
public class FileFormatPropertiesCSV extends FileFormatProperties implements Serializable
{
@@ -34,7 +36,7 @@
private String delim;
private boolean fill;
private double fillValue;
- private String naStrings;
+ private HashSet<String> naStrings;
private boolean sparse;
@@ -45,9 +47,10 @@
this.fill = DataExpression.DEFAULT_DELIM_FILL;
this.fillValue = DataExpression.DEFAULT_DELIM_FILL_VALUE;
this.sparse = DataExpression.DEFAULT_DELIM_SPARSE;
- this.naStrings = null;
- if( LOG.isDebugEnabled() )
- LOG.debug("FileFormatPropertiesCSV: " + toString());
+ this.naStrings = UtilFunctions.defaultNaString;
+ if(LOG.isDebugEnabled())
+ LOG.debug("FileFormatPropertiesCSV: " + this.toString());
+
}
public FileFormatPropertiesCSV(boolean hasHeader, String delim, boolean fill, double fillValue, String naStrings) {
@@ -55,17 +58,23 @@
this.delim = delim;
this.fill = fill;
this.fillValue = fillValue;
- this.naStrings = naStrings;
- if( LOG.isDebugEnabled() )
- LOG.debug("FileFormatPropertiesCSV full settings: " + toString());
+
+ this.naStrings = new HashSet<>();
+ for(String s: naStrings.split(DataExpression.DELIM_NA_STRING_SEP)){
+ this.naStrings.add(s);
+ }
+ if(LOG.isDebugEnabled())
+ LOG.debug("FileFormatPropertiesCSV full settings: " + this.toString());
}
public FileFormatPropertiesCSV(boolean hasHeader, String delim, boolean sparse) {
this.header = hasHeader;
this.delim = delim;
this.sparse = sparse;
- if( LOG.isDebugEnabled() )
- LOG.debug("FileFormatPropertiesCSV medium settings: " + toString());
+ this.naStrings = UtilFunctions.defaultNaString;
+ if(LOG.isDebugEnabled()){
+ LOG.debug("FileFormatPropertiesCSV medium settings: " + this.toString());
+ }
}
public boolean hasHeader() {
@@ -80,7 +89,7 @@
return delim;
}
- public String getNAStrings() {
+ public HashSet<String> getNAStrings() {
return naStrings;
}
diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java
index e83b16e..87e81c8 100644
--- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java
+++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java
@@ -21,6 +21,7 @@
import java.io.IOException;
import java.io.InputStream;
+import java.util.Set;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -47,7 +48,7 @@
*
*/
public class FrameReaderTextCSV extends FrameReader {
- protected FileFormatPropertiesCSV _props = null;
+ protected FileFormatPropertiesCSV _props;
public FrameReaderTextCSV(FileFormatPropertiesCSV props) {
_props = props;
@@ -119,6 +120,7 @@
boolean isFill = _props.isFill();
double dfillValue = _props.getFillValue();
String sfillValue = String.valueOf(_props.getFillValue());
+ Set<String> naValues = _props.getNAStrings();
String delim = _props.getDelim();
// create record reader
@@ -158,7 +160,7 @@
for(String part : parts) // foreach cell
{
part = part.trim();
- if(part.isEmpty()) {
+ if(part.isEmpty() || naValues.contains(part)) {
if(isFill && dfillValue != 0)
dest.set(row, col, UtilFunctions.stringToObject(schema[col], sfillValue));
emptyValuesFound = true;
diff --git a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java
index d476768..75095b2 100644
--- a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java
@@ -31,6 +31,7 @@
import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedList;
+import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
@@ -243,9 +244,10 @@
* @param str string to split
* @param delim delimiter
* @param tokens array for tokens, length needs to match the number of tokens
+ * @param naStrings the set of strings that should be mapped to a null value.
* @return string array of tokens
*/
- public static String[] splitCSV(String str, String delim, String[] tokens)
+ public static String[] splitCSV(String str, String delim, String[] tokens, Set<String> naStrings)
{
// check for empty input
if( str == null || str.isEmpty() )
@@ -255,6 +257,7 @@
int from = 0, to = 0;
int len = str.length();
int dlen = delim.length();
+ String curString;
int pos = 0;
while( from < len ) { // for all tokens
if( str.charAt(from) == CSV_QUOTE_CHAR
@@ -277,7 +280,8 @@
// slice out token and advance position
to = (to >= 0) ? to : len;
- tokens[pos++] = str.substring(from, to);
+ curString = str.substring(from, to);
+ tokens[pos++] = (naStrings.contains(curString)) ? null: curString;
from = to + delim.length();
}
diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSV.java
index 40088d0..63d0b04 100644
--- a/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSV.java
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSV.java
@@ -25,6 +25,7 @@
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
import org.apache.commons.lang.StringUtils;
@@ -41,7 +42,7 @@
public class ReaderTextCSV extends MatrixReader
{
- private FileFormatPropertiesCSV _props = null;
+ private final FileFormatPropertiesCSV _props;
public ReaderTextCSV(FileFormatPropertiesCSV props) {
_props = props;
@@ -66,7 +67,7 @@
//core read
ret = readCSVMatrixFromHDFS(path, job, fs, ret, rlen, clen, blen,
- _props.hasHeader(), _props.getDelim(), _props.isFill(), _props.getFillValue() );
+ _props.hasHeader(), _props.getDelim(), _props.isFill(), _props.getFillValue(), _props.getNAStrings() );
//finally check if change of sparse/dense block representation required
//(nnz explicitly maintained during read)
@@ -84,7 +85,7 @@
//core read
long lnnz = readCSVMatrixFromInputStream(is, "external inputstream", ret, new MutableInt(0), rlen, clen,
- blen, _props.hasHeader(), _props.getDelim(), _props.isFill(), _props.getFillValue(), true);
+ blen, _props.hasHeader(), _props.getDelim(), _props.isFill(), _props.getFillValue(), true, _props.getNAStrings());
//finally check if change of sparse/dense block representation required
ret.setNonZeros( lnnz );
@@ -95,7 +96,7 @@
@SuppressWarnings("unchecked")
private static MatrixBlock readCSVMatrixFromHDFS( Path path, JobConf job, FileSystem fs, MatrixBlock dest,
- long rlen, long clen, int blen, boolean hasHeader, String delim, boolean fill, double fillValue )
+ long rlen, long clen, int blen, boolean hasHeader, String delim, boolean fill, double fillValue, HashSet<String> naStrings )
throws IOException, DMLRuntimeException
{
//prepare file paths in alphanumeric order
@@ -119,7 +120,7 @@
MutableInt row = new MutableInt(0);
for(int fileNo=0; fileNo<files.size(); fileNo++) {
lnnz += readCSVMatrixFromInputStream(fs.open(files.get(fileNo)), path.toString(), dest,
- row, rlen, clen, blen, hasHeader, delim, fill, fillValue, fileNo==0);
+ row, rlen, clen, blen, hasHeader, delim, fill, fillValue, fileNo==0, naStrings);
}
//post processing
@@ -129,7 +130,7 @@
}
private static long readCSVMatrixFromInputStream( InputStream is, String srcInfo, MatrixBlock dest, MutableInt rowPos,
- long rlen, long clen, int blen, boolean hasHeader, String delim, boolean fill, double fillValue, boolean first )
+ long rlen, long clen, int blen, boolean hasHeader, String delim, boolean fill, double fillValue, boolean first, HashSet<String> naStrings )
throws IOException
{
boolean sparse = dest.isInSparseFormat();
@@ -163,7 +164,7 @@
cellValue = fillValue;
}
else {
- cellValue = UtilFunctions.parseToDouble(part);
+ cellValue = UtilFunctions.parseToDouble(part, naStrings);
}
if ( cellValue != 0 ) {
dest.appendValue(row, col, cellValue);
@@ -193,7 +194,7 @@
cellValue = fillValue;
}
else {
- cellValue = UtilFunctions.parseToDouble(part);
+ cellValue = UtilFunctions.parseToDouble(part, naStrings);
}
if ( cellValue != 0 ) {
a.set(row, col, cellValue);
diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java b/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java
index 7c37f47..26a2eaf 100644
--- a/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderTextCSVParallel.java
@@ -22,6 +22,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
+import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
@@ -100,7 +101,7 @@
// Second Read Pass (read, parse strings, append to matrix block)
readCSVMatrixFromHDFS(splits, path, job, ret, rlen, clen, blen,
_props.hasHeader(), _props.getDelim(), _props.isFill(),
- _props.getFillValue());
+ _props.getFillValue(), _props.getNAStrings());
//post-processing (representation-specific, change of sparse/dense block representation)
// - no sorting required for CSV because it is read in sorted order per row
@@ -126,7 +127,7 @@
private void readCSVMatrixFromHDFS(InputSplit[] splits, Path path, JobConf job,
MatrixBlock dest, long rlen, long clen, int blen,
- boolean hasHeader, String delim, boolean fill, double fillValue)
+ boolean hasHeader, String delim, boolean fill, double fillValue, HashSet<String> naStrings)
throws IOException
{
FileInputFormat.addInputPath(job, path);
@@ -142,7 +143,7 @@
int splitCount = 0;
for (InputSplit split : splits) {
tasks.add( new CSVReadTask(split, _offsets, informat, job, dest,
- rlen, clen, hasHeader, delim, fill, fillValue, splitCount++) );
+ rlen, clen, hasHeader, delim, fill, fillValue, splitCount++, naStrings) );
}
pool.invokeAll(tasks);
pool.shutdown();
@@ -283,11 +284,12 @@
private boolean _rc = true;
private Exception _exception = null;
private long _nnz;
+ private HashSet<String> _naStrings;
public CSVReadTask(InputSplit split, SplitOffsetInfos offsets,
TextInputFormat informat, JobConf job, MatrixBlock dest,
long rlen, long clen, boolean hasHeader, String delim,
- boolean fill, double fillValue, int splitCount)
+ boolean fill, double fillValue, int splitCount, HashSet<String> naStrings)
{
_split = split;
_splitoffsets = offsets; // new SplitOffsetInfos(offsets);
@@ -304,6 +306,7 @@
_delim = delim;
_rc = true;
_splitCount = splitCount;
+ _naStrings = naStrings;
}
public boolean getReturnCode() {
@@ -358,7 +361,7 @@
cellValue = _fillValue;
}
else {
- cellValue = UtilFunctions.parseToDouble(part);
+ cellValue = UtilFunctions.parseToDouble(part,_naStrings);
}
if( cellValue != 0 ) {
@@ -389,7 +392,7 @@
cellValue = _fillValue;
}
else {
- cellValue = UtilFunctions.parseToDouble(part);
+ cellValue = UtilFunctions.parseToDouble(part,_naStrings);
}
if( cellValue != 0 ) {
a.set(row, col, cellValue);
diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderTextLIBSVM.java b/src/main/java/org/apache/sysds/runtime/io/ReaderTextLIBSVM.java
index 403858b..7694b69 100644
--- a/src/main/java/org/apache/sysds/runtime/io/ReaderTextLIBSVM.java
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderTextLIBSVM.java
@@ -174,7 +174,7 @@
for( int i=1; i<parts.length; i++ ) {
//parse non-zero: <index#>:<value#>
String[] pair = parts[i].split(IOUtilFunctions.LIBSVM_INDEX_DELIM);
- vect.append(Integer.parseInt(pair[0])-1, UtilFunctions.parseToDouble(pair[1]));
+ vect.append(Integer.parseInt(pair[0])-1, UtilFunctions.parseToDouble(pair[1],UtilFunctions.defaultNaString));
}
vect.append(clen-1, label);
return vect.size();
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 5cbd3de..325819b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -19,6 +19,22 @@
package org.apache.sysds.runtime.matrix.data;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.io.Serializable;
+import java.lang.ref.SoftReference;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+
import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.io.Writable;
import org.apache.sysds.api.DMLException;
@@ -30,13 +46,8 @@
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
-import java.io.*;
-import java.lang.ref.SoftReference;
-import java.util.*;
-import java.util.concurrent.ThreadLocalRandom;
-
@SuppressWarnings({"rawtypes","unchecked"}) //allow generic native arrays
-public class FrameBlock implements Writable, CacheBlock, Externalizable
+public class FrameBlock implements CacheBlock, Externalizable
{
private static final long serialVersionUID = -3993450030207130665L;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 36eb7dc..ce16369 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -39,6 +39,7 @@
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseBlockFactory;
import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.CM;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
@@ -51,7 +52,6 @@
import org.apache.sysds.runtime.functionobjects.ReduceDiag;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
@@ -59,9 +59,9 @@
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
+import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -87,6 +87,8 @@
*/
public class LibMatrixAgg
{
+ // private static final Log LOG = LogFactory.getLog(LibMatrixAgg.class.getName());
+
//internal configuration parameters
private static final boolean NAN_AWARENESS = false;
private static final long PAR_NUMCELL_THRESHOLD1 = 1024*1024; //Min 1M elements
@@ -198,7 +200,7 @@
}
public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, AggregateUnaryOperator uaop) {
- //prepare meta data
+
AggType aggtype = getAggType(uaop);
final int m = in.rlen;
final int m2 = out.rlen;
@@ -224,8 +226,6 @@
//cleanup output and change representation (if necessary)
out.recomputeNonZeros();
out.examSparsity();
-
- //System.out.println("uagg ("+in.rlen+","+in.clen+","+in.sparse+") in "+time.stop()+"ms.");
}
public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, AggregateUnaryOperator uaop, int k) {
@@ -254,7 +254,7 @@
out.reset(m2, n2, false); //always dense
out.allocateDenseBlock();
}
-
+
//core multi-threaded unary aggregate computation
//(currently: always parallelization over number of rows)
try {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index e96a00a..deba22f 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -240,7 +240,7 @@
boolean computeMean = (_mvMethodList[i] == MVMethod.GLOBAL_MEAN || _isMVScaled.get(i) );
if(computeMean) {
// global_mean
- double d = UtilFunctions.parseToDouble(w);
+ double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
_meanFn.execute2(_meanList[i], d, _countList[i]);
if (_isMVScaled.get(i) && _mvscMethodList[i] == MVMethod.GLOBAL_MODE)
@@ -263,7 +263,7 @@
{
int colID = _scnomvList[i];
w = UtilFunctions.unquote(words[colID-1].trim());
- double d = UtilFunctions.parseToDouble(w);
+ double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
_scnomvCountList[i]++; // not required, this is always equal to total #records processed
_meanFn.execute2(_scnomvMeanList[i], d, _scnomvCountList[i]);
if(_scnomvMethodList[i] == MVMethod.GLOBAL_MODE)
diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 3f10dd0..cb9ad41 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -304,8 +304,8 @@
if( map == null )
throw new IOException("Binning map for column '"+name+"' (id="+colID+") not existing.");
String[] fields = map.split(TfUtils.TXMTD_SEP);
- double min = UtilFunctions.parseToDouble(fields[1]);
- double binwidth = UtilFunctions.parseToDouble(fields[3]);
+ double min = UtilFunctions.parseToDouble(fields[1], UtilFunctions.defaultNaString);
+ double binwidth = UtilFunctions.parseToDouble(fields[3], UtilFunctions.defaultNaString);
int nbins = UtilFunctions.parseToInt(fields[4]);
//materialize bins to support equi-width/equi-height
for( int i=0; i<nbins; i++ ) {
diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
index 3ac19f2..086408f 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
@@ -19,6 +19,17 @@
package org.apache.sysds.runtime.util;
+import java.io.IOException;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.StringTokenizer;
+
import org.apache.commons.lang.StringUtils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.BlockRealMatrix;
@@ -27,12 +38,15 @@
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.data.BasicTensorBlock;
+import org.apache.sysds.runtime.data.DataTensorBlock;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFactory;
-import org.apache.sysds.runtime.data.DataTensorBlock;
import org.apache.sysds.runtime.data.SparseBlock;
-import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -56,17 +70,6 @@
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.meta.DataCharacteristics;
-import java.io.IOException;
-import java.text.DecimalFormat;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.BitSet;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map.Entry;
-import java.util.StringTokenizer;
-
/**
* This class provides methods to read and write matrix blocks from to HDFS using different data formats.
@@ -1217,6 +1220,47 @@
return sb.toString();
}
+ public static String toString(ListObject list,int rows, int cols, boolean sparse, String separator, String lineSeparator, int rowsToPrint, int colsToPrint, int decimal)
+ {
+ StringBuilder sb = new StringBuilder();
+ sb.append("List containing:\n");
+ sb.append("[");
+ for(Data x : list.getData()){
+ if( x instanceof MatrixObject) {
+ sb.append("\nMatrix:\n");
+ MatrixObject dat = (MatrixObject) x;
+ MatrixBlock matrix = (MatrixBlock) dat.acquireRead();
+ sb.append(DataConverter.toString(matrix, sparse, separator, lineSeparator, rows, cols, decimal));
+ dat.release();
+ }
+ else if( x instanceof TensorObject ) {
+ sb.append("\n");
+ TensorObject dat = (TensorObject) x;
+ TensorBlock tensor = (TensorBlock) dat.acquireRead();
+ sb.append(DataConverter.toString(tensor, sparse, separator,
+ lineSeparator, "[", "]", rows, cols, decimal));
+ dat.release();
+ }
+ else if( x instanceof FrameObject ) {
+ sb.append("\n");
+ FrameObject dat = (FrameObject) x;
+ FrameBlock frame = (FrameBlock) dat.acquireRead();
+ sb.append(DataConverter.toString(frame, sparse, separator, lineSeparator, rows, cols, decimal));
+ dat.release();
+ }
+ else if (x instanceof ListObject){
+ ListObject dat = (ListObject) x;
+ sb.append(DataConverter.toString(dat, cols, rows,sparse, separator, lineSeparator, rows, cols, decimal));
+ }else{
+ sb.append(x.toString());
+ }
+ sb.append(", ");
+ }
+ sb.delete(sb.length() -2, sb.length());
+ sb.append("]");
+ return sb.toString();
+ }
+
public static int[] getTensorDimensions(ExecutionContext ec, CPOperand dims) {
int[] tDims;
switch (dims.getDataType()) {
diff --git a/src/main/java/org/apache/sysds/runtime/util/FastStringTokenizer.java b/src/main/java/org/apache/sysds/runtime/util/FastStringTokenizer.java
index 4521446..8ff78b5 100644
--- a/src/main/java/org/apache/sysds/runtime/util/FastStringTokenizer.java
+++ b/src/main/java/org/apache/sysds/runtime/util/FastStringTokenizer.java
@@ -86,6 +86,6 @@
}
public double nextDouble() {
- return UtilFunctions.parseToDouble(nextToken());
+ return UtilFunctions.parseToDouble(nextToken(),UtilFunctions.defaultNaString);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index 88b7041..ad0b6d7 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -19,6 +19,15 @@
package org.apache.sysds.runtime.util;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Future;
+
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.math3.random.RandomDataGenerator;
@@ -36,15 +45,9 @@
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.TensorCharacteristics;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.BitSet;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.Future;
+public class UtilFunctions {
+ // private static final Log LOG = LogFactory.getLog(UtilFunctions.class.getName());
-public class UtilFunctions
-{
//for accurate cast of double values to int and long
//IEEE754: binary64 (double precision) eps = 2^(-53) = 1.11 * 10^(-16)
//(same epsilon as used for matrix index cast in R)
@@ -54,6 +57,12 @@
//because it determines the max hash domain size
public static final long ADD_PRIME1 = 99991;
public static final int DIVIDE_PRIME = 1405695061;
+
+ public static final HashSet<String> defaultNaString = new HashSet<>();
+
+ static{
+ defaultNaString.add("NA");
+ }
public static int intHashCode(int key1, int key2) {
return 31 * (31 + key1) + key2;
@@ -351,11 +360,12 @@
* environments because Double.parseDouble relied on a synchronized cache
* (which was replaced with thread-local caches in JDK8).
*
- * @param str string to parse to double
+ * @param str string to parse to double
+ * @param isNan collection of NaN strings which, if encountered, should be parsed to a NaN value
* @return double value
*/
- public static double parseToDouble(String str) {
- return "NA".equals(str) ?
+ public static double parseToDouble(String str, Set<String> isNan ) {
+ return isNan.contains(str) ?
Double.NaN :
Double.parseDouble(str);
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4647efd..56845c5 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -27,7 +27,6 @@
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
-import java.io.OutputStream;
import java.io.PrintStream;
import java.net.ServerSocket;
import java.util.ArrayList;
@@ -90,7 +89,7 @@
public abstract class AutomatedTestBase {
private static final Log LOG = LogFactory.getLog(AutomatedTestBase.class.getName());
-
+
public static final boolean EXCEPTION_EXPECTED = true;
public static final boolean EXCEPTION_NOT_EXPECTED = false;
@@ -100,9 +99,9 @@
public static final int FED_WORKER_WAIT = 500; // in ms
- // With OpenJDK 8u242 on Windows, the new changes in JDK are not allowing
- // to set the native library paths internally thus breaking the code.
- // That is why, these static assignments to java.library.path and hadoop.home.dir
+ // With OpenJDK 8u242 on Windows, the new changes in JDK are not allowing
+ // to set the native library paths internally thus breaking the code.
+ // That is why, these static assignments to java.library.path and hadoop.home.dir
// (for native winutils) have been removed.
/**
@@ -126,7 +125,7 @@
protected enum CodegenTestType {
DEFAULT, FUSE_ALL, FUSE_NO_REDUNDANCY;
-
+
public String getCodgenConfig() {
switch(this) {
case DEFAULT:
@@ -135,12 +134,12 @@
return "SystemDS-config-codegen-fuse-all.xml";
case FUSE_NO_REDUNDANCY:
return "SystemDS-config-codegen-fuse-no-redundancy.xml";
- default:
- throw new RuntimeException("Unsupported codegen test config: "+this.name());
+ default:
+ throw new RuntimeException("Unsupported codegen test config: " + this.name());
}
}
}
-
+
/**
* Location under which we create local temporary directories for test cases. To adjust where testTemp is located,
* use -Dsystemds.testTemp.root.dir=<new location>. This is necessary if any parent directories are
@@ -197,11 +196,11 @@
private int iExpectedStdOutState = 0;
private String unexpectedStdOut;
private int iUnexpectedStdOutState = 0;
- private PrintStream originalPrintStreamStd = null;
+ // private PrintStream originalPrintStreamStd = null;
private String expectedStdErr;
private int iExpectedStdErrState = 0;
- private PrintStream originalErrStreamStd = null;
+ // private PrintStream originalErrStreamStd = null;
private boolean outputBuffering = true;
@@ -326,14 +325,16 @@
LOG.info("This test case overrides default configuration with " + tmp.getPath());
return tmp;
}
-
+
protected ExecMode setExecMode(ExecType instType) {
switch(instType) {
- case SPARK: return setExecMode(ExecMode.SPARK);
- default: return setExecMode(ExecMode.HYBRID);
+ case SPARK:
+ return setExecMode(ExecMode.SPARK);
+ default:
+ return setExecMode(ExecMode.HYBRID);
}
}
-
+
protected ExecMode setExecMode(ExecMode execMode) {
ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -482,16 +483,17 @@
return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, null);
}
- protected double [][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR,
- MatrixCharacteristics mc) {
+ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR,
+ MatrixCharacteristics mc) {
return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, null);
}
- protected double [][] writeInputMatrixWithMTD(String name, double[][] matrix, PrivacyConstraint privacyConstraint) {
+ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, PrivacyConstraint privacyConstraint) {
return writeInputMatrixWithMTD(name, matrix, false, null, privacyConstraint);
}
- protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR, PrivacyConstraint privacyConstraint) {
+ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR,
+ PrivacyConstraint privacyConstraint) {
MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length,
OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, privacyConstraint);
@@ -755,8 +757,8 @@
}
public static String readDMLMetaDataValue(String fileName, String outputDir, String key) throws JSONException {
- JSONObject meta = getMetaDataJSON(fileName, outputDir);
- return meta.get(key).toString();
+ JSONObject meta = getMetaDataJSON(fileName, outputDir);
+ return meta.get(key).toString();
}
public static ValueType readDMLMetaDataValueType(String fileName) {
@@ -1004,11 +1006,12 @@
LOG.info("R is finished (in " + ((double) t1 - t0) / 1000000000 + " sec)");
}
catch(Exception e) {
- if(e.getMessage().contains("ERROR: R has ended irregularly")){
+ if(e.getMessage().contains("ERROR: R has ended irregularly")) {
StringBuilder errorMessage = new StringBuilder();
errorMessage.append(e.getMessage());
fail(errorMessage.toString());
- }else {
+ }
+ else {
e.printStackTrace();
StringBuilder errorMessage = new StringBuilder();
errorMessage.append("failed to run script " + executionFile);
@@ -1093,7 +1096,8 @@
* @param expectedException expected exception
* @param maxMRJobs specifies a maximum limit for the number of MR jobs. If set to -1 there is no limit.
*/
- protected ByteArrayOutputStream runTest(boolean newWay, boolean exceptionExpected, Class<?> expectedException, int maxMRJobs) {
+ protected ByteArrayOutputStream runTest(boolean newWay, boolean exceptionExpected, Class<?> expectedException,
+ int maxMRJobs) {
return runTest(newWay, exceptionExpected, expectedException, null, maxMRJobs);
}
@@ -1105,17 +1109,16 @@
* @param expectedException The expected exception
* @return The Std output from the test.
*/
- protected ByteArrayOutputStream runTest(Class<?> expectedException){
- return runTest( expectedException, -1);
+ protected ByteArrayOutputStream runTest(Class<?> expectedException) {
+ return runTest(expectedException, -1);
}
- protected ByteArrayOutputStream runTest(Class<?> expectedException, int maxSparkInst){
- return runTest( expectedException, null, maxSparkInst);
+ protected ByteArrayOutputStream runTest(Class<?> expectedException, int maxSparkInst) {
+ return runTest(expectedException, null, maxSparkInst);
}
- protected ByteArrayOutputStream runTest(Class<?> expectedException, String errMessage,
- int maxSparkInst){
- return runTest(true, expectedException!= null, expectedException, errMessage, maxSparkInst);
+ protected ByteArrayOutputStream runTest(Class<?> expectedException, String errMessage, int maxSparkInst) {
+ return runTest(true, expectedException != null, expectedException, errMessage, maxSparkInst);
}
/**
@@ -1130,8 +1133,8 @@
* @param errMessage expected error message
* @param maxSparkInst specifies a maximum limit for the number of MR jobs. If set to -1 there is no limit.
*/
- protected ByteArrayOutputStream runTest(boolean newWay, boolean exceptionExpected, Class<?> expectedException, String errMessage,
- int maxSparkInst) {
+ protected ByteArrayOutputStream runTest(boolean newWay, boolean exceptionExpected, Class<?> expectedException,
+ String errMessage, int maxSparkInst) {
String executionFile = sourceDirectory + selectedTest + ".dml";
@@ -1245,8 +1248,8 @@
args.add("-gpu");
}
- protected int getRandomAvailablePort(){
- try (ServerSocket availableSocket = new ServerSocket(0) ) {
+ protected int getRandomAvailablePort() {
+ try(ServerSocket availableSocket = new ServerSocket(0)) {
return availableSocket.getLocalPort();
}
catch(IOException e) {
@@ -1266,13 +1269,13 @@
args.add(fedWorkArgs[i]);
String[] finalArguments = args.toArray(new String[args.size()]);
-
+
try {
t = new Thread(() -> {
try {
DMLScript.main(finalArguments);
}
- catch(IOException e){
+ catch(IOException e) {
}
});
@@ -1286,16 +1289,16 @@
}
private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) {
- if (e.getCause() != null) {
+ if(e.getCause() != null) {
result |= rCompareException(exceptionExpected, errMessage, e.getCause(), result);
}
- if (exceptionExpected && errMessage != null && e.getMessage().contains(errMessage)) {
+ if(exceptionExpected && errMessage != null && e.getMessage().contains(errMessage)) {
result = true;
}
return result;
}
- private String getStackTraceString(Throwable e, int level){
+ public static String getStackTraceString(Throwable e, int level) {
StringBuilder sb = new StringBuilder();
sb.append("\nLEVEL : " + level);
sb.append("\nException : " + e.getClass());
@@ -1304,22 +1307,21 @@
if(ste.toString().contains("org.junit")) {
sb.append("\n > ... Stopping Stack Trace at JUnit");
break;
- }else{
- sb.append("\n"+ level+" > " + ste);
+ }
+ else {
+ sb.append("\n" + level + " > " + ste);
}
}
- if(e.getCause() == null){
+ if(e.getCause() == null) {
return sb.toString();
}
- sb.append(getStackTraceString(e.getCause(), level +1));
+ sb.append(getStackTraceString(e.getCause(), level + 1));
return sb.toString();
}
- public void cleanupScratchSpace()
- {
- try
- {
- //parse config file
+ public void cleanupScratchSpace() {
+ try {
+ // parse config file
DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
// delete the scratch_space and all contents
@@ -1327,17 +1329,15 @@
String dir = conf.getTextValue(DMLConfig.SCRATCH_SPACE);
HDFSTool.deleteFileIfExistOnHDFS(dir);
}
- catch (Exception ex)
- {
- //ex.printStackTrace();
- return; //no effect on tests
+ catch(Exception ex) {
+ // ex.printStackTrace();
+ return; // no effect on tests
}
}
/**
* <p>
- * Checks if a process-local temporary directory exists
- * in the current working directory.
+ * Checks if a process-local temporary directory exists in the current working directory.
* </p>
*
* @return true if a process-local temp directory is present.
@@ -1354,7 +1354,8 @@
String pLocalDir = sb.toString();
return HDFSTool.existsFileOnHDFS(pLocalDir);
- } catch (Exception ex) {
+ }
+ catch(Exception ex) {
ex.printStackTrace();
return true;
}
@@ -1371,15 +1372,13 @@
/**
* <p>
- * Compares the results of the computation with the expected ones with a
- * specified tolerance.
+ * Compares the results of the computation with the expected ones with a specified tolerance.
* </p>
*
- * @param epsilon
- * tolerance
+ * @param epsilon tolerance
*/
protected void compareResultsWithR(double epsilon) {
- for (int i = 0; i < comparisonFiles.length; i++) {
+ for(int i = 0; i < comparisonFiles.length; i++) {
TestUtils.compareDMLHDFSFileWithRFile(comparisonFiles[i], outputDirectories[i], epsilon);
}
}
@@ -1393,32 +1392,31 @@
compareResultsWithR(0);
}
- protected void compareResultsWithMM () {
- TestUtils.compareMMMatrixWithJavaMatrix (comparisonFiles[0], outputDirectories[0], 0);
+ protected void compareResultsWithMM() {
+ TestUtils.compareMMMatrixWithJavaMatrix(comparisonFiles[0], outputDirectories[0], 0);
}
+
/**
* <p>
- * Compares the results of the computation with the expected ones with a
- * specified tolerance.
+ * Compares the results of the computation with the expected ones with a specified tolerance.
* </p>
*
- * @param epsilon
- * tolerance
+ * @param epsilon tolerance
*/
protected void compareResults(double epsilon) {
- for (int i = 0; i < comparisonFiles.length; i++) {
+ for(int i = 0; i < comparisonFiles.length; i++) {
/* Note that DML scripts may generate a file with only scalar value */
- if (outputDirectories[i].endsWith(".scalar")) {
- String javaFile = comparisonFiles[i].replace(".scalar", "");
- String dmlFile = outputDirectories[i].replace(".scalar", "");
- TestUtils.compareDMLScalarWithJavaScalar(javaFile, dmlFile, epsilon);
+ if(outputDirectories[i].endsWith(".scalar")) {
+ String javaFile = comparisonFiles[i].replace(".scalar", "");
+ String dmlFile = outputDirectories[i].replace(".scalar", "");
+ TestUtils.compareDMLScalarWithJavaScalar(javaFile, dmlFile, epsilon);
}
else {
TestUtils.compareDMLMatrixWithJavaMatrix(comparisonFiles[i], outputDirectories[i], epsilon);
}
}
}
-
+
/**
* <p>
* Compares the results of the computation of the frame with the expected ones.
@@ -1427,58 +1425,54 @@
* @param schema the frame schema
*/
protected void compareResults(ValueType[] schema) {
- for (int i = 0; i < comparisonFiles.length; i++) {
+ for(int i = 0; i < comparisonFiles.length; i++) {
TestUtils.compareDMLFrameWithJavaFrame(schema, comparisonFiles[i], outputDirectories[i]);
}
}
-
/**
* Compare results of the computation with the expected results where rows may be permuted.
+ *
* @param epsilon
*/
- protected void compareResultsRowsOutOfOrder(double epsilon)
- {
- for (int i = 0; i < comparisonFiles.length; i++) {
+ protected void compareResultsRowsOutOfOrder(double epsilon) {
+ for(int i = 0; i < comparisonFiles.length; i++) {
/* Note that DML scripts may generate a file with only scalar value */
- if (outputDirectories[i].endsWith(".scalar")) {
- String javaFile = comparisonFiles[i].replace(".scalar", "");
- String dmlFile = outputDirectories[i].replace(".scalar", "");
- TestUtils.compareDMLScalarWithJavaScalar(javaFile, dmlFile, epsilon);
+ if(outputDirectories[i].endsWith(".scalar")) {
+ String javaFile = comparisonFiles[i].replace(".scalar", "");
+ String dmlFile = outputDirectories[i].replace(".scalar", "");
+ TestUtils.compareDMLScalarWithJavaScalar(javaFile, dmlFile, epsilon);
}
else {
- TestUtils.compareDMLMatrixWithJavaMatrixRowsOutOfOrder(comparisonFiles[i], outputDirectories[i], epsilon);
+ TestUtils
+ .compareDMLMatrixWithJavaMatrixRowsOutOfOrder(comparisonFiles[i], outputDirectories[i], epsilon);
}
}
}
/**
- * Checks that the number of Spark instructions that the current test case has
- * compiled is equal to the expected number. Generates a JUnit error message
- * if the number is out of line.
+ * Checks that the number of Spark instructions that the current test case has compiled is equal to the expected
+ * number. Generates a JUnit error message if the number is out of line.
*
- * @param expectedNumCompiled
- * number of Spark instructions that the current test case is
- * expected to compile
+ * @param expectedNumCompiled number of Spark instructions that the current test case is expected to compile
*/
protected void checkNumCompiledSparkInst(int expectedNumCompiled) {
assertEquals("Unexpected number of compiled Spark instructions.",
- expectedNumCompiled, Statistics.getNoOfCompiledSPInst());
+ expectedNumCompiled,
+ Statistics.getNoOfCompiledSPInst());
}
/**
- * Checks that the number of Spark instructions that the current test case has
- * executed (as opposed to compiling into the execution plan) is equal to
- * the expected number. Generates a JUnit error message if the number is out
- * of line.
+ * Checks that the number of Spark instructions that the current test case has executed (as opposed to compiling
+ * into the execution plan) is equal to the expected number. Generates a JUnit error message if the number is out of
+ * line.
*
- * @param expectedNumExecuted
- * number of Spark instructions that the current test case is
- * expected to run
+ * @param expectedNumExecuted number of Spark instructions that the current test case is expected to run
*/
protected void checkNumExecutedSparkInst(int expectedNumExecuted) {
assertEquals("Unexpected number of executed Spark instructions.",
- expectedNumExecuted, Statistics.getNoOfExecutedSPInst());
+ expectedNumExecuted,
+ Statistics.getNoOfExecutedSPInst());
}
/**
@@ -1486,17 +1480,13 @@
* Checks the results of a computation against a number of characteristics.
* </p>
*
- * @param rows
- * number of rows
- * @param cols
- * number of columns
- * @param min
- * minimum value
- * @param max
- * maximum value
+ * @param rows number of rows
+ * @param cols number of columns
+ * @param min minimum value
+ * @param max maximum value
*/
protected void checkResults(long rows, long cols, double min, double max) {
- for (int i = 0; i < outputDirectories.length; i++) {
+ for(int i = 0; i < outputDirectories.length; i++) {
TestUtils.checkMatrix(outputDirectories[i], rows, cols, min, max);
}
}
@@ -1507,170 +1497,59 @@
* </p>
*/
protected void checkForResultExistence() {
- for (int i = 0; i < outputDirectories.length; i++) {
+ for(int i = 0; i < outputDirectories.length; i++) {
TestUtils.checkForOutputExistence(outputDirectories[i]);
}
}
@After
public void tearDown() {
- if( LOG.isTraceEnabled() )
- LOG.trace("Duration: " + (System.currentTimeMillis() - lTimeBeforeTest) + "ms");
+ LOG.trace("Duration: " + (System.currentTimeMillis() - lTimeBeforeTest) + "ms");
- assertTrue("expected String did not occur: " + expectedStdOut, iExpectedStdOutState == 0
- || iExpectedStdOutState == 2);
- assertTrue("expected String did not occur (stderr): " + expectedStdErr, iExpectedStdErrState == 0
- || iExpectedStdErrState == 2);
+ assertTrue("expected String did not occur: " + expectedStdOut,
+ iExpectedStdOutState == 0 || iExpectedStdOutState == 2);
+ assertTrue("expected String did not occur (stderr): " + expectedStdErr,
+ iExpectedStdErrState == 0 || iExpectedStdErrState == 2);
assertFalse("unexpected String occurred: " + unexpectedStdOut, iUnexpectedStdOutState == 1);
TestUtils.displayAssertionBuffer();
-
- if (!isOutAndExpectedDeletionDisabled()) {
+ if(!isOutAndExpectedDeletionDisabled()) {
TestUtils.removeHDFSDirectories(inputDirectories.toArray(new String[inputDirectories.size()]));
TestUtils.removeFiles(inputRFiles.toArray(new String[inputRFiles.size()]));
// The following cleanup code is disabled (see [SYSTEMML-256]) until we can figure out
// what test cases are creating temporary directories at the root of the project.
- //TestUtils.removeTemporaryFiles();
+ // TestUtils.removeTemporaryFiles();
TestUtils.clearDirectory(baseDirectory + OUTPUT_DIR);
TestUtils.removeHDFSFiles(expectedFiles.toArray(new String[expectedFiles.size()]));
TestUtils.clearDirectory(baseDirectory + EXPECTED_DIR);
- TestUtils.removeFiles(new String[] { sourceDirectory + selectedTest + ".dmlt" });
- TestUtils.removeFiles(new String[] { sourceDirectory + selectedTest + ".Rt" });
+ TestUtils.removeFiles(new String[] {sourceDirectory + selectedTest + ".dmlt"});
+ TestUtils.removeFiles(new String[] {sourceDirectory + selectedTest + ".Rt"});
}
TestUtils.clearAssertionInformation();
}
-
+
public boolean bufferContainsString(ByteArrayOutputStream buffer, String str){
return Arrays.stream(buffer.toString().split("\n")).anyMatch(x -> x.contains(str));
}
/**
- * Disables the deletion of files and directories in the output and expected
- * folder for this test.
+ * Disables the deletion of files and directories in the output and expected folder for this test.
*/
public void disableOutAndExpectedDeletion() {
setOutAndExpectedDeletionDisabled(true);
}
/**
- * Enables detection of expected output of a line in standard output stream.
- *
- * @param expectedLine
- */
- public void setExpectedStdOut(String expectedLine) {
- this.expectedStdOut = expectedLine;
- originalPrintStreamStd = System.out;
- iExpectedStdOutState = 1;
- System.setOut(new PrintStream(new ExpectedOutputStream()));
- }
-
- /**
- * This class is used to compare the standard output stream against an
- * expected string.
- *
- */
- class ExpectedOutputStream extends OutputStream {
- private String line = "";
-
- @Override
- public void write(int b) throws IOException {
- line += String.valueOf((char) b);
- if (((char) b) == '\n') {
- /** new line */
- if (line.contains(expectedStdOut)) {
- iExpectedStdOutState = 2;
- System.setOut(originalPrintStreamStd);
- } else {
- // Reset buffer
- line = "";
- }
- }
- originalPrintStreamStd.write(b);
- }
- }
-
- public void setExpectedStdErr(String expectedLine) {
- this.expectedStdErr = expectedLine;
- originalErrStreamStd = System.err;
- iExpectedStdErrState = 1;
- System.setErr(new PrintStream(new ExpectedErrorStream()));
- }
-
- /**
- * This class is used to compare the standard error stream against an
- * expected string.
- *
- */
- class ExpectedErrorStream extends OutputStream {
- private String line = "";
-
- @Override
- public void write(int b) throws IOException {
- line += String.valueOf((char) b);
- if (((char) b) == '\n') {
- /** new line */
- if (line.contains(expectedStdErr)) {
- iExpectedStdErrState = 2;
- System.setErr(originalErrStreamStd);
- } else {
- // Reset buffer
- line = "";
- }
- }
- originalErrStreamStd.write(b);
- }
- }
-
- /**
- * Enables detection of unexpected output of a line in standard output stream.
- *
- * @param unexpectedLine String that should not occur in stdout.
- */
- public void setUnexpectedStdOut(String unexpectedLine) {
- this.unexpectedStdOut = unexpectedLine;
- originalPrintStreamStd = System.out;
- System.setOut(new PrintStream(new UnexpectedOutputStream()));
- }
-
- public void setOutputBuffering(boolean flag) {
- outputBuffering = flag;
- }
-
- /**
- * This class is used to compare the standard output stream against
- * an unexpected string.
- */
- class UnexpectedOutputStream extends OutputStream {
- private String line = "";
-
- @Override
- public void write(int b) throws IOException {
- line += String.valueOf((char) b);
- if (((char) b) == '\n') {
- /** new line */
- if (line.contains(unexpectedStdOut)) {
- iUnexpectedStdOutState = 1; // error!
- } else {
- line = ""; // reset buffer
- }
- }
- originalPrintStreamStd.write(b);
- }
- }
-
- /**
* <p>
* Generates a matrix containing easy to debug values in its cells.
* </p>
*
* @param rows
* @param cols
- * @param bContainsZeros
- * If true, the matrix contains zeros. If false, the matrix
- * contains only positive values.
+ * @param bContainsZeros If true, the matrix contains zeros. If false, the matrix contains only positive values.
* @return
*/
protected double[][] createNonRandomMatrixValues(int rows, int cols, boolean bContainsZeros) {
@@ -1679,8 +1558,7 @@
/**
* <p>
- * Generates a matrix containing easy to debug values in its cells. The
- * generated matrix contains zero values
+ * Generates a matrix containing easy to debug values in its cells. The generated matrix contains zero values
* </p>
*
* @param rows
@@ -1698,14 +1576,16 @@
return isOutAndExpectedDeletionDisabled;
}
+ public void setOutputBuffering(boolean flag) {
+ outputBuffering = flag;
+ }
+
/**
* Call this method from a subclass's setUp() method.
- * @param isOutAndExpectedDeletionDisabled
- * TRUE to disable code that deletes temporary files for this
- * test case
+ *
+ * @param isOutAndExpectedDeletionDisabled TRUE to disable code that deletes temporary files for this test case
*/
- protected void setOutAndExpectedDeletionDisabled(
- boolean isOutAndExpectedDeletionDisabled) {
+ protected void setOutAndExpectedDeletionDisabled(boolean isOutAndExpectedDeletionDisabled) {
this.isOutAndExpectedDeletionDisabled = isOutAndExpectedDeletionDisabled;
}
@@ -1743,11 +1623,11 @@
return sourceDirectory + selectedTest + ".R";
}
- protected String getRCmd(String ... args) {
+ protected String getRCmd(String... args) {
StringBuilder sb = new StringBuilder();
sb.append("Rscript ");
sb.append(getRScript());
- for (String arg : args) {
+ for(String arg : args) {
sb.append(" ");
sb.append(arg);
}
@@ -1755,20 +1635,20 @@
}
private boolean isTargetTestDirectory(String path) {
- return (path != null && path.contains(getClass().getSimpleName()));
+ return(path != null && path.contains(getClass().getSimpleName()));
}
private void setCacheDirectory(String directory) {
cacheDir = (directory != null) ? directory : "";
- if (cacheDir.length() > 0 && !cacheDir.endsWith("/")) {
+ if(cacheDir.length() > 0 && !cacheDir.endsWith("/")) {
cacheDir += "/";
}
}
private static String getSourceDirectory(String testDirectory) {
String sourceDirectory = "";
- if (null != testDirectory) {
- if (testDirectory.endsWith("/"))
+ if(null != testDirectory) {
+ if(testDirectory.endsWith("/"))
testDirectory = testDirectory.substring(0, testDirectory.length() - "/".length());
sourceDirectory = testDirectory.substring(0, testDirectory.lastIndexOf("/") + "/".length());
}
@@ -1780,53 +1660,53 @@
* Adds a frame to the input path and writes it to a file.
* </p>
*
- * @param name
- * directory name
- * @param data
- * two dimensional frame data
- * @param bIncludeR
- * generates also the corresponding R frame data
+ * @param name directory name
+ * @param data two dimensional frame data
+ * @param bIncludeR generates also the corresponding R frame data
* @throws IOException
*/
- protected double[][] writeInputFrame(String name, double[][] data, boolean bIncludeR, ValueType[] schema, FileFormat fmt) throws IOException {
+ protected double[][] writeInputFrame(String name, double[][] data, boolean bIncludeR, ValueType[] schema,
+ FileFormat fmt) throws IOException {
String completePath = baseDirectory + INPUT_DIR + name;
String completeRPath = baseDirectory + INPUT_DIR + name + ".csv";
try {
cleanupExistingData(baseDirectory + INPUT_DIR + name, bIncludeR);
- } catch (IOException e) {
+ }
+ catch(IOException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
TestUtils.writeTestFrame(completePath, data, schema, fmt);
- if (bIncludeR) {
+ if(bIncludeR) {
TestUtils.writeTestFrame(completeRPath, data, schema, FileFormat.CSV, true);
inputRFiles.add(completeRPath);
}
- if (DEBUG)
+ if(DEBUG)
TestUtils.writeTestFrame(DEBUG_TEMP_DIR + completePath, data, schema, fmt);
inputDirectories.add(baseDirectory + INPUT_DIR + name);
return data;
}
- protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR, ValueType[] schema, FileFormat fmt) throws IOException {
- MatrixCharacteristics mc = new MatrixCharacteristics(data.length, data[0].length, OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
+ protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR, ValueType[] schema,
+ FileFormat fmt) throws IOException {
+ MatrixCharacteristics mc = new MatrixCharacteristics(data.length, data[0].length,
+ OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
return writeInputFrameWithMTD(name, data, bIncludeR, mc, schema, fmt);
}
- protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR, MatrixCharacteristics mc, ValueType[] schema, FileFormat fmt) throws IOException {
+ protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR,
+ MatrixCharacteristics mc, ValueType[] schema, FileFormat fmt) throws IOException {
writeInputFrame(name, data, bIncludeR, schema, fmt);
// write metadata file
- try
- {
+ try {
String completeMTDPath = baseDirectory + INPUT_DIR + name + ".mtd";
HDFSTool.writeMetaDataFile(completeMTDPath, null, schema, DataType.FRAME, mc, fmt);
}
- catch(IOException e)
- {
+ catch(IOException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
@@ -1839,60 +1719,58 @@
* Adds a frame to the input path and writes it to a file.
* </p>
*
- * @param name
- * directory name
- * @param data
- * two dimensional frame data
+ * @param name directory name
+ * @param data two dimensional frame data
* @param schema
* @param oi
* @throws IOException
*/
protected double[][] writeInputFrame(String name, double[][] data, ValueType[] schema, FileFormat fmt)
- throws IOException
- {
+ throws IOException {
return writeInputFrame(name, data, false, schema, fmt);
}
protected boolean heavyHittersContainsString(String... str) {
- for( String opcode : Statistics.getCPHeavyHitterOpCodes())
- for( String s : str )
+ for(String opcode : Statistics.getCPHeavyHitterOpCodes())
+ for(String s : str)
if(opcode.equals(s))
return true;
return false;
}
-
+
protected boolean heavyHittersContainsString(String str, int minCount) {
int count = 0;
- for( String opcode : Statistics.getCPHeavyHitterOpCodes())
+ for(String opcode : Statistics.getCPHeavyHitterOpCodes())
count += opcode.equals(str) ? 1 : 0;
- return (count >= minCount);
+ return(count >= minCount);
}
-
+
protected boolean heavyHittersContainsSubString(String... str) {
- for( String opcode : Statistics.getCPHeavyHitterOpCodes())
- for( String s : str )
+ for(String opcode : Statistics.getCPHeavyHitterOpCodes())
+ for(String s : str)
if(opcode.contains(s))
return true;
return false;
}
-
+
protected boolean heavyHittersContainsSubString(String str, int minCount) {
int count = 0;
- for( String opcode : Statistics.getCPHeavyHitterOpCodes())
+ for(String opcode : Statistics.getCPHeavyHitterOpCodes())
count += opcode.contains(str) ? 1 : 0;
- return (count >= minCount);
+ return(count >= minCount);
}
- protected boolean checkedPrivacyConstraintsContains(PrivacyLevel... levels){
- for ( PrivacyLevel level : levels)
- if (!(CheckedConstraintsLog.getCheckedConstraints().containsKey(level)))
+ protected boolean checkedPrivacyConstraintsContains(PrivacyLevel... levels) {
+ for(PrivacyLevel level : levels)
+ if(!(CheckedConstraintsLog.getCheckedConstraints().containsKey(level)))
return false;
return true;
}
- protected boolean checkedPrivacyConstraintsAbove(Map<PrivacyLevel,Long> levelCounts){
- for ( Map.Entry<PrivacyLevel,Long> levelCount : levelCounts.entrySet()){
- if (!(CheckedConstraintsLog.getCheckedConstraints().get(levelCount.getKey()).longValue() >= levelCount.getValue()))
+ protected boolean checkedPrivacyConstraintsAbove(Map<PrivacyLevel, Long> levelCounts) {
+ for(Map.Entry<PrivacyLevel, Long> levelCount : levelCounts.entrySet()) {
+ if(!(CheckedConstraintsLog.getCheckedConstraints().get(levelCount.getKey()).longValue() >= levelCount
+ .getValue()))
return false;
}
return true;
@@ -1902,19 +1780,19 @@
* Create a SystemDS-preferred Spark Session.
*
* @param appName the application name
- * @param master the master value (ie, "local", etc)
+ * @param master the master value (i.e., "local", etc.)
* @return Spark Session
*/
public static SparkSession createSystemDSSparkSession(String appName, String master) {
Builder builder = SparkSession.builder();
- if (appName != null) {
+ if(appName != null) {
builder.appName(appName);
}
- if (master != null) {
+ if(master != null) {
builder.master(master);
}
builder.config("spark.driver.maxResultSize", "0");
- if (SparkExecutionContext.FAIR_SCHEDULER_MODE) {
+ if(SparkExecutionContext.FAIR_SCHEDULER_MODE) {
builder.config("spark.scheduler.mode", "FAIR");
}
builder.config("spark.locality.wait", "5s");
@@ -1925,7 +1803,8 @@
public static String getMatrixAsString(double[][] matrix) {
try {
return DataConverter.toString(DataConverter.convertToMatrixBlock(matrix));
- } catch (DMLRuntimeException e) {
+ }
+ catch(DMLRuntimeException e) {
return "N/A";
}
}
diff --git a/src/test/java/org/apache/sysds/test/applications/NNTest.java b/src/test/java/org/apache/sysds/test/applications/NNTest.java
index 17fcb86..d03b768 100644
--- a/src/test/java/org/apache/sysds/test/applications/NNTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/NNTest.java
@@ -20,6 +20,7 @@
package org.apache.sysds.test.applications;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromFile;
+import static org.junit.Assert.assertTrue;
import org.junit.Test;
import org.apache.sysds.api.mlcontext.Script;
@@ -36,7 +37,7 @@
@Test
public void testNNLibrary() {
Script script = dmlFromFile(TEST_SCRIPT);
- setUnexpectedStdOut(ERROR_STRING);
- ml.execute(script);
+ String stdOut = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(stdOut, !stdOut.contains(ERROR_STRING));
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/data/misc/TimeTest.java b/src/test/java/org/apache/sysds/test/functions/data/misc/TimeTest.java
index 843ccc0..927c8ff 100644
--- a/src/test/java/org/apache/sysds/test/functions/data/misc/TimeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/data/misc/TimeTest.java
@@ -20,6 +20,9 @@
package org.apache.sysds.test.functions.data.misc;
import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.test.AutomatedTestBase;
@@ -56,9 +59,10 @@
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
//programArgs = new String[]{"-explain", "hops", "-stats", "2", "-args", output("B") };
- programArgs = new String[]{"-explain", "-args", output("B") };
+ programArgs = new String[]{"-args", output("B") };
- runTest(true, false, null, -1);
+ String out = runTest(null).toString();
+ assertTrue("stdout:" + out, out.contains("time diff : "));
}
finally {
rtplatform = platformOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java
index 8e098a6..3d591ef 100644
--- a/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java
@@ -408,7 +408,7 @@
JavaPairRDD<LongWritable,Text> rddIn = (JavaPairRDD<LongWritable,Text>) sc
.hadoopFile(fnameIn, iinfo.inputFormatClass, iinfo.keyClass, iinfo.valueClass);
JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils
- .csvToBinaryBlock(sc, rddIn, mc, null, false, separator, false, 0)
+ .csvToBinaryBlock(sc, rddIn, mc, null, false, separator, false, 0, UtilFunctions.defaultNaString)
.mapToPair(new LongFrameToLongWritableFrameFunction());
rddOut.saveAsHadoopFile(fnameOut, LongWritable.class, FrameBlock.class, oinfo.outputFormatClass);
break;
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/FormatChangeTest.java b/src/test/java/org/apache/sysds/test/functions/io/csv/FormatChangeTest.java
index 604b134..381a96c 100644
--- a/src/test/java/org/apache/sysds/test/functions/io/csv/FormatChangeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/FormatChangeTest.java
@@ -88,7 +88,7 @@
runTest(true, false, null, -1);
// Test TextCell -> CSV conversion
- LOG.info("TextCell -> CSV");
+ LOG.debug("TextCell -> CSV");
programArgs[2] = "text";
programArgs[3] = csvFile;
programArgs[4] = "csv";
@@ -97,7 +97,7 @@
compareFiles(rows, cols, sparsity, txtFile, "text", csvFile);
// Test BinaryBlock -> CSV conversion
- LOG.info("BinaryBlock -> CSV");
+ LOG.debug("BinaryBlock -> CSV");
programArgs = oldProgramArgs;
programArgs[1] = binFile;
programArgs[2] = "binary";
@@ -108,7 +108,7 @@
compareFiles(rows, cols, sparsity, binFile, "binary", csvFile);
// Test CSV -> TextCell conversion
- LOG.info("CSV -> TextCell");
+ LOG.debug("CSV -> TextCell");
programArgs = oldProgramArgs;
programArgs[1] = csvFile;
programArgs[2] = "csv";
@@ -119,7 +119,7 @@
compareFiles(rows, cols, sparsity, txtFile, "text", csvFile);
// Test CSV -> BinaryBlock conversion
- LOG.info("CSV -> BinaryBlock");
+ LOG.debug("CSV -> BinaryBlock");
programArgs = oldProgramArgs;
programArgs[1] = csvFile;
programArgs[2] = "csv";
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest.java b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest.java
index ba2e93e..5476430 100644
--- a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest.java
@@ -26,46 +26,39 @@
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
-/**
- * JUnit Test cases to evaluate the functionality of reading CSV files.
- *
- * Test 1: read() with a mtd file.
- *
- * Test 2: read(format="csv") without mtd file.
- *
- * Test 3: read() with complete mtd file.
- *
- */
public abstract class ReadCSVTest extends CSVTestBase {
protected abstract int getId();
-
- @Test
- public void testCSV_Sequential_CP1() {
- runCSVTest(getId(), ExecMode.SINGLE_NODE, false);
+ protected String getInputCSVFileName() {
+ return "transfusion_" + getId();
}
- @Test
- public void testCSV_Parallel_CP1() {
- runCSVTest(getId(), ExecMode.SINGLE_NODE, true);
- }
+ @Test
+ public void testCSV_Sequential_CP1() {
+ runCSVTest(getId(), ExecMode.SINGLE_NODE, false);
+ }
- @Test
- public void testCSV_Sequential_CP() {
- runCSVTest(getId(), ExecMode.HYBRID, false);
- }
+ @Test
+ public void testCSV_Parallel_CP1() {
+ runCSVTest(getId(), ExecMode.SINGLE_NODE, true);
+ }
- @Test
- public void testCSV_Parallel_CP() {
- runCSVTest(getId(), ExecMode.HYBRID, true);
- }
+ @Test
+ public void testCSV_Sequential_CP() {
+ runCSVTest(getId(), ExecMode.HYBRID, false);
+ }
+
+ @Test
+ public void testCSV_Parallel_CP() {
+ runCSVTest(getId(), ExecMode.HYBRID, true);
+ }
@Test
public void testCSV_SP() {
runCSVTest(getId(), ExecMode.SPARK, false);
}
- protected void runCSVTest(int testNumber, ExecMode platform, boolean parallel) {
+ protected String runCSVTest(int testNumber, ExecMode platform, boolean parallel) {
ExecMode oldPlatform = rtplatform;
rtplatform = platform;
@@ -75,6 +68,7 @@
boolean oldpar = CompilerConfig.FLAG_PARREADWRITE_TEXT;
+ String output;
try {
CompilerConfig.FLAG_PARREADWRITE_TEXT = parallel;
@@ -83,7 +77,7 @@
loadTestConfiguration(config);
String HOME = SCRIPT_DIR + TEST_DIR;
- String inputMatrixNameNoExtension = HOME + INPUT_DIR + "transfusion_" + testNumber;
+ String inputMatrixNameNoExtension = HOME + INPUT_DIR + getInputCSVFileName();
String inputMatrixNameWithExtension = inputMatrixNameNoExtension + ".csv";
String dmlOutput = output("dml.scalar");
String rOutput = output("R.scalar");
@@ -94,7 +88,7 @@
fullRScriptName = HOME + "csv_verify2.R";
rCmd = "Rscript" + " " + fullRScriptName + " " + inputMatrixNameNoExtension + ".single.csv " + rOutput;
- runTest(true, false, null, -1);
+ output = runTest(true, false, null, -1).toString();
runRScript(true);
double dmlScalar = TestUtils.readDMLScalar(dmlOutput);
@@ -107,5 +101,6 @@
CompilerConfig.FLAG_PARREADWRITE_TEXT = oldpar;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
+ return output;
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest4Nan.java b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest4Nan.java
new file mode 100644
index 0000000..cde0182
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest4Nan.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.io.csv;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.conf.CompilerConfig;
+import org.apache.sysds.test.TestConfiguration;
+
+public class ReadCSVTest4Nan extends ReadCSVTest {
+
+ private final static String TEST_NAME = "ReadCSVTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + ReadCSVTest4Nan.class.getSimpleName() + "/";
+
+ @Override
+ protected int getId() {
+ return 4;
+ }
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String getInputCSVFileName() {
+ return "nan_integers_" + getId();
+ }
+
+ @Override
+ protected String runCSVTest(int testNumber, ExecMode platform, boolean parallel) {
+ ExecMode oldPlatform = rtplatform;
+ rtplatform = platform;
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean oldpar = CompilerConfig.FLAG_PARREADWRITE_TEXT;
+ String output;
+ try {
+ CompilerConfig.FLAG_PARREADWRITE_TEXT = parallel;
+
+ TestConfiguration config = getTestConfiguration(getTestName());
+
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ String inputMatrixNameNoExtension = HOME + INPUT_DIR + getInputCSVFileName();
+ String inputMatrixNameWithExtension = inputMatrixNameNoExtension + ".csv";
+ String dmlOutput = output("dml.scalar");
+
+ fullDMLScriptName = HOME + getTestName() + "_" + testNumber + ".dml";
+ programArgs = new String[] {"-args", inputMatrixNameWithExtension, dmlOutput};
+
+ output = runTest(true, false, null, -1).toString();
+ assertTrue(output.contains("NaN"));
+ }
+ finally {
+ rtplatform = oldPlatform;
+ CompilerConfig.FLAG_PARREADWRITE_TEXT = oldpar;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ return output;
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest5Nan.java b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest5Nan.java
new file mode 100644
index 0000000..a1af5a5
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest5Nan.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.io.csv;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.conf.CompilerConfig;
+import org.apache.sysds.test.TestConfiguration;
+
+public class ReadCSVTest5Nan extends ReadCSVTest4Nan {
+
+ private final static String TEST_NAME = "ReadCSVTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + ReadCSVTest5Nan.class.getSimpleName() + "/";
+
+ @Override
+ protected int getId() {
+ return 5;
+ }
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String runCSVTest(int testNumber, ExecMode platform, boolean parallel) {
+ ExecMode oldPlatform = rtplatform;
+ rtplatform = platform;
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean oldpar = CompilerConfig.FLAG_PARREADWRITE_TEXT;
+
+ String output;
+ try {
+ CompilerConfig.FLAG_PARREADWRITE_TEXT = parallel;
+
+ TestConfiguration config = getTestConfiguration(getTestName());
+
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ String inputMatrixNameNoExtension = HOME + INPUT_DIR + getInputCSVFileName();
+ String inputMatrixNameWithExtension = inputMatrixNameNoExtension + ".csv";
+ String dmlOutput = output("dml.scalar");
+
+ fullDMLScriptName = HOME + getTestName() + "_" + testNumber + ".dml";
+ programArgs = new String[] {"-args", inputMatrixNameWithExtension, dmlOutput};
+
+ output = runTest(true, false, null, -1).toString();
+ String expected = "NaN 8.000 NaN NaN";
+ assertTrue("\nout: " + output + "\n expected: " + expected, output.contains(expected));
+
+ }
+ finally {
+ rtplatform = oldPlatform;
+ CompilerConfig.FLAG_PARREADWRITE_TEXT = oldpar;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ return output;
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest6Nan.java b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest6Nan.java
new file mode 100644
index 0000000..47e3b7e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadCSVTest6Nan.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.io.csv;
+
+public class ReadCSVTest6Nan extends ReadCSVTest5Nan {
+
+ private final static String TEST_NAME = "ReadCSVTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + ReadCSVTest6Nan.class.getSimpleName() + "/";
+
+ @Override
+ protected int getId() {
+ return 6;
+ }
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest1.java b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest1.java
new file mode 100644
index 0000000..aa755a3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest1.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.io.csv;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.conf.CompilerConfig;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Test;
+
+public class ReadFrameCSVTest1 extends CSVTestBase {
+
+ private final static String TEST_NAME = "ReadFrameTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + ReadFrameCSVTest1.class.getSimpleName() + "/";
+ private final static String[] expectedStrings = new String[] {"goodbye 2 four new york"};
+
+ protected String getInputCSVFileName() {
+ return "frame_" + getId();
+ }
+
+ protected int getId() {
+ return 1;
+ }
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ protected String[] getExpectedStrings() {
+ return expectedStrings;
+ }
+
+ // @Test
+ // public void testCSV_Sequential_CP1() {
+ // runCSVTest(getId(), ExecMode.SINGLE_NODE, false);
+ // }
+
+ // @Test
+ // public void testCSV_Parallel_CP1() {
+ // runCSVTest(getId(), ExecMode.SINGLE_NODE, true);
+ // }
+
+ // @Test
+ // public void testCSV_Sequential_CP() {
+ // runCSVTest(getId(), ExecMode.HYBRID, false);
+ // }
+
+ // @Test
+ // public void testCSV_Parallel_CP() {
+ // runCSVTest(getId(), ExecMode.HYBRID, true);
+ // }
+
+ @Test
+ public void testCSV_SP() {
+ runCSVTest(getId(), ExecMode.SPARK, false);
+ }
+
+ protected void runCSVTest(int testNumber, ExecMode platform, boolean parallel) {
+ ExecMode oldPlatform = rtplatform;
+ rtplatform = platform;
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean oldpar = CompilerConfig.FLAG_PARREADWRITE_TEXT;
+ String output;
+ try {
+ CompilerConfig.FLAG_PARREADWRITE_TEXT = parallel;
+
+ TestConfiguration config = getTestConfiguration(getTestName());
+
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ String inputMatrixNameNoExtension = HOME + INPUT_DIR + getInputCSVFileName();
+ String inputMatrixNameWithExtension = inputMatrixNameNoExtension + ".csv";
+ String dmlOutput = output("dml.scalar");
+
+ fullDMLScriptName = HOME + getTestName() + "_" + testNumber + ".dml";
+ programArgs = new String[] {"-args", inputMatrixNameWithExtension, dmlOutput};
+
+ output = runTest(true, false, null, -1).toString();
+
+ }
+ finally {
+ rtplatform = oldPlatform;
+ CompilerConfig.FLAG_PARREADWRITE_TEXT = oldpar;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+
+ for(String str : getExpectedStrings()) {
+ assertTrue("\nout: " + output + "\n expected: " + str, output.contains(str));
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest2.java b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest2.java
new file mode 100644
index 0000000..d54723e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest2.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.io.csv;
+
+public class ReadFrameCSVTest2 extends ReadFrameCSVTest1 {
+
+ private final static String TEST_NAME = "ReadFrameTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + ReadFrameCSVTest2.class.getSimpleName() + "/";
+
+ private final static String[] expectedStrings = new String[] {"goodbye 2 four new york"};
+
+ @Override
+ protected String getInputCSVFileName() {
+ return "frame_" + getId();
+ }
+
+ @Override
+ protected int getId() {
+ return 2;
+ }
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String[] getExpectedStrings() {
+ return expectedStrings;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest3.java b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest3.java
new file mode 100644
index 0000000..d6a70e4
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/io/csv/ReadFrameCSVTest3.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.io.csv;
+
+public class ReadFrameCSVTest3 extends ReadFrameCSVTest1 {
+
+ private final static String TEST_NAME = "ReadFrameTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + ReadFrameCSVTest3.class.getSimpleName() + "/";
+ private final static String[] expectedStrings = new String[] {"null 1 five null", "null 2 four new york"};
+
+ @Override
+ protected String getInputCSVFileName() {
+ return "frame_" + getId();
+ }
+
+ @Override
+ protected int getId() {
+ return 3;
+ }
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String[] getExpectedStrings() {
+ return expectedStrings;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageMLContextTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageMLContextTest.java
index 9ff0020..91579ef 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageMLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageMLContextTest.java
@@ -20,6 +20,7 @@
package org.apache.sysds.test.functions.lineage;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dml;
+import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.List;
@@ -49,9 +50,10 @@
"print('sum: '+sum(M+M));"
+ "print(lineage(M+M));"
).in("M", javaRDD, mm);
- setExpectedStdOut("sum: 30.0");
ml.setLineage(ReuseCacheType.NONE);
+ String out = MLContextTestBase.executeAndCaptureStdOut(ml,script).getRight();
+ assertTrue(out.contains("sum: 30.0"));
ml.execute(script);
}
@@ -71,11 +73,12 @@
+ "s = lineage(M+M);"
+"if( sum(M) < 0 ) print(s);"
).in("M", javaRDD, mm);
- setExpectedStdOut("sum: 30.0");
ml.setLineage(ReuseCacheType.REUSE_FULL);
- ml.execute(script);
- ml.execute(script); //w/ reuse
+ String out = MLContextTestBase.executeAndCaptureStdOut(ml,script).getRight();
+ assertTrue(out.contains("sum: 30.0"));
+ out = MLContextTestBase.executeAndCaptureStdOut(ml,script).getRight();
+ assertTrue(out.contains("sum: 30.0"));
}
@Test
@@ -97,15 +100,15 @@
ml.setLineage(ReuseCacheType.REUSE_FULL);
- setExpectedStdOut("sum: 30.0");
- ml.execute(script);
+ String out = MLContextTestBase.executeAndCaptureStdOut(ml,script).getRight();
+ assertTrue(out.contains("sum: 30.0"));
list.add("4 4 5");
JavaRDD<String> javaRDD2 = sc.parallelize(list);
MatrixMetadata mm2 = new MatrixMetadata(MatrixFormat.IJV, 4, 4);
script.in("M", javaRDD2, mm2);
- setExpectedStdOut("sum: 40.0");
- ml.execute(script); //w/o reuse
+ out = MLContextTestBase.executeAndCaptureStdOut(ml,script).getRight();
+ assertTrue(out.contains("sum: 40.0"));
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
index 697e9e9..147417e 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
@@ -19,12 +19,12 @@
package org.apache.sysds.test.functions.mlcontext;
-import static org.junit.Assert.assertTrue;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dml;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromFile;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromInputStream;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromLocalFile;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromUrl;
+import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.FileInputStream;
@@ -40,6 +40,8 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
@@ -55,8 +57,6 @@
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-import org.junit.Assert;
-import org.junit.Test;
import org.apache.sysds.api.mlcontext.MLContextConversionUtil;
import org.apache.sysds.api.mlcontext.MLContextException;
import org.apache.sysds.api.mlcontext.MLContextUtil;
@@ -73,6 +73,8 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
import scala.Tuple1;
import scala.Tuple2;
@@ -84,90 +86,92 @@
public class MLContextTest extends MLContextTestBase {
+ private static final Log LOG = LogFactory.getLog(MLContextTest.class.getName());
+
@Test
public void testBuiltinConstantsTest() {
- System.out.println("MLContextTest - basic builtin constants test");
+ LOG.debug("MLContextTest - basic builtin constants test");
Script script = dmlFromFile(baseDirectory + File.separator + "builtin-constants-test.dml");
- ml.execute(script);
+ executeAndCaptureStdOut(script);
Assert.assertTrue(Statistics.getNoOfExecutedSPInst() == 0);
}
-
+
@Test
public void testBasicExecuteEvalTest() {
- System.out.println("MLContextTest - basic eval test");
- setExpectedStdOut("10");
+ LOG.debug("MLContextTest - basic eval test");
Script script = dmlFromFile(baseDirectory + File.separator + "eval-test.dml");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("10"));
}
-
+
@Test
public void testRewriteExecuteEvalTest() {
- System.out.println("MLContextTest - eval rewrite test");
+ LOG.debug("MLContextTest - eval rewrite test");
Script script = dmlFromFile(baseDirectory + File.separator + "eval2-test.dml");
- ml.execute(script);
+ executeAndCaptureStdOut(script);
Assert.assertTrue(Statistics.getNoOfExecutedSPInst() == 0);
}
-
+
@Test
public void testExecuteEvalBuiltinTest() {
- System.out.println("MLContextTest - eval builtin test");
- setExpectedStdOut("TRUE");
- ml.setExplain(true);
+ LOG.debug("MLContextTest - eval builtin test");
Script script = dmlFromFile(baseDirectory + File.separator + "eval3-builtin-test.dml");
- ml.execute(script);
+ ml.setExplain(true);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("TRUE"));
ml.setExplain(false);
}
-
+
@Test
public void testExecuteEvalNestedBuiltinTest() {
- System.out.println("MLContextTest - eval builtin test");
- setExpectedStdOut("TRUE");
- ml.setExplain(true);
+ LOG.debug("MLContextTest - eval builtin test");
Script script = dmlFromFile(baseDirectory + File.separator + "eval4-nested_builtin-test.dml");
- ml.execute(script);
+ ml.setExplain(true);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("TRUE"));
ml.setExplain(false);
}
@Test
public void testCreateDMLScriptBasedOnStringAndExecute() {
- System.out.println("MLContextTest - create DML script based on string and execute");
+ LOG.debug("MLContextTest - create DML script based on string and execute");
String testString = "Create DML script based on string and execute";
- setExpectedStdOut(testString);
Script script = dml("print('" + testString + "');");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains(testString));
}
@Test
public void testCreateDMLScriptBasedOnFileAndExecute() {
- System.out.println("MLContextTest - create DML script based on file and execute");
- setExpectedStdOut("hello world");
+ LOG.debug("MLContextTest - create DML script based on file and execute");
Script script = dmlFromFile(baseDirectory + File.separator + "hello-world.dml");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hello world"));
}
@Test
public void testCreateDMLScriptBasedOnInputStreamAndExecute() throws IOException {
- System.out.println("MLContextTest - create DML script based on InputStream and execute");
- setExpectedStdOut("hello world");
+ LOG.debug("MLContextTest - create DML script based on InputStream and execute");
File file = new File(baseDirectory + File.separator + "hello-world.dml");
- try( InputStream is = new FileInputStream(file) ) {
+ try(InputStream is = new FileInputStream(file)) {
Script script = dmlFromInputStream(is);
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hello world"));
}
}
@Test
public void testCreateDMLScriptBasedOnLocalFileAndExecute() {
- System.out.println("MLContextTest - create DML script based on local file and execute");
- setExpectedStdOut("hello world");
+ LOG.debug("MLContextTest - create DML script based on local file and execute");
File file = new File(baseDirectory + File.separator + "hello-world.dml");
Script script = dmlFromLocalFile(file);
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hello world"));
}
@Test
public void testCreateDMLScriptBasedOnURL() throws MalformedURLException {
- System.out.println("MLContextTest - create DML script based on URL");
+ LOG.debug("MLContextTest - create DML script based on URL");
String urlString = "https://raw.githubusercontent.com/apache/systemml/master/src/test/scripts/applications/hits/HITS.dml";
URL url = new URL(urlString);
Script script = dmlFromUrl(url);
@@ -178,7 +182,7 @@
@Test
public void testCreateDMLScriptBasedOnURLString() {
- System.out.println("MLContextTest - create DML script based on URL string");
+ LOG.debug("MLContextTest - create DML script based on URL string");
String urlString = "https://raw.githubusercontent.com/apache/systemml/master/src/test/scripts/applications/hits/HITS.dml";
Script script = dmlFromUrl(urlString);
String expectedContent = "Licensed to the Apache Software Foundation";
@@ -188,26 +192,26 @@
@Test
public void testExecuteDMLScript() {
- System.out.println("MLContextTest - execute DML script");
+ LOG.debug("MLContextTest - execute DML script");
String testString = "hello dml world!";
- setExpectedStdOut(testString);
Script script = new Script("print('" + testString + "');");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains(testString));
}
@Test
public void testInputParametersAddDML() {
- System.out.println("MLContextTest - input parameters add DML");
+ LOG.debug("MLContextTest - input parameters add DML");
String s = "x = $X; y = $Y; print('x + y = ' + (x + y));";
Script script = dml(s).in("$X", 3).in("$Y", 4);
- setExpectedStdOut("x + y = 7");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("x + y = 7"));
}
@Test
public void testJavaRDDCSVSumDML() {
- System.out.println("MLContextTest - JavaRDD<String> CSV sum DML");
+ LOG.debug("MLContextTest - JavaRDD<String> CSV sum DML");
List<String> list = new ArrayList<>();
list.add("1,2,3");
@@ -216,13 +220,13 @@
JavaRDD<String> javaRDD = sc.parallelize(list);
Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
@Test
public void testJavaRDDIJVSumDML() {
- System.out.println("MLContextTest - JavaRDD<String> IJV sum DML");
+ LOG.debug("MLContextTest - JavaRDD<String> IJV sum DML");
List<String> list = new ArrayList<>();
list.add("1 1 5");
@@ -233,13 +237,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, 3, 3);
Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD, mm);
- setExpectedStdOut("sum: 15.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 15.0"));
}
@Test
public void testJavaRDDAndInputParameterDML() {
- System.out.println("MLContextTest - JavaRDD<String> and input parameter DML");
+ LOG.debug("MLContextTest - JavaRDD<String> and input parameter DML");
List<String> list = new ArrayList<>();
list.add("1,2");
@@ -248,13 +252,13 @@
String s = "M = M + $X; print('sum: ' + sum(M));";
Script script = dml(s).in("M", javaRDD).in("$X", 1);
- setExpectedStdOut("sum: 14.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 14.0"));
}
@Test
public void testInputMapDML() {
- System.out.println("MLContextTest - input map DML");
+ LOG.debug("MLContextTest - input map DML");
List<String> list = new ArrayList<>();
list.add("10,20");
@@ -271,27 +275,28 @@
String s = "M = M + $X; print('sum: ' + sum(M));";
Script script = dml(s).in(inputs);
- setExpectedStdOut("sum: 108.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 108.0"));
}
@Test
public void testCustomExecutionStepDML() {
- System.out.println("MLContextTest - custom execution step DML");
+ LOG.debug("MLContextTest - custom execution step DML");
String testString = "custom execution step";
- setExpectedStdOut(testString);
Script script = new Script("print('" + testString + "');");
ScriptExecutor scriptExecutor = new ScriptExecutor() {
@Override
- protected void showExplanation() {}
+ protected void showExplanation() {
+ }
};
- ml.execute(script, scriptExecutor);
+ String out = executeAndCaptureStdOut(ml, script, scriptExecutor).getRight();
+ assertTrue(out.contains(testString));
}
@Test
public void testRDDSumCSVDML() {
- System.out.println("MLContextTest - RDD<String> CSV sum DML");
+ LOG.debug("MLContextTest - RDD<String> CSV sum DML");
List<String> list = new ArrayList<>();
list.add("1,1,1");
@@ -301,13 +306,13 @@
RDD<String> rdd = JavaRDD.toRDD(javaRDD);
Script script = dml("print('sum: ' + sum(M));").in("M", rdd);
- setExpectedStdOut("sum: 18.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 18.0"));
}
@Test
public void testRDDSumIJVDML() {
- System.out.println("MLContextTest - RDD<String> IJV sum DML");
+ LOG.debug("MLContextTest - RDD<String> IJV sum DML");
List<String> list = new ArrayList<>();
list.add("1 1 1");
@@ -320,13 +325,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, 3, 3);
Script script = dml("print('sum: ' + sum(M));").in("M", rdd, mm);
- setExpectedStdOut("sum: 10.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 10.0"));
}
@Test
public void testDataFrameSumDMLDoublesWithNoIDColumn() {
- System.out.println("MLContextTest - DataFrame sum DML, doubles with no ID column");
+ LOG.debug("MLContextTest - DataFrame sum DML, doubles with no ID column");
List<String> list = new ArrayList<>();
list.add("10,20,30");
@@ -345,13 +350,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
- setExpectedStdOut("sum: 450.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 450.0"));
}
@Test
public void testDataFrameSumDMLDoublesWithIDColumn() {
- System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column");
+ LOG.debug("MLContextTest - DataFrame sum DML, doubles with ID column");
List<String> list = new ArrayList<>();
list.add("1,1,2,3");
@@ -371,13 +376,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_INDEX);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
@Test
public void testDataFrameSumDMLDoublesWithIDColumnSortCheck() {
- System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column sort check");
+ LOG.debug("MLContextTest - DataFrame sum DML, doubles with ID column sort check");
List<String> list = new ArrayList<>();
list.add("3,7,8,9");
@@ -397,13 +402,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_INDEX);
Script script = dml("print('M[1,1]: ' + as.scalar(M[1,1]));").in("M", dataFrame, mm);
- setExpectedStdOut("M[1,1]: 1.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("M[1,1]: 1.0"));
}
@Test
public void testDataFrameSumDMLVectorWithIDColumn() {
- System.out.println("MLContextTest - DataFrame sum DML, vector with ID column");
+ LOG.debug("MLContextTest - DataFrame sum DML, vector with ID column");
List<Tuple2<Double, Vector>> list = new ArrayList<>();
list.add(new Tuple2<>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
@@ -421,13 +426,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
@Test
public void testDataFrameSumDMLMllibVectorWithIDColumn() {
- System.out.println("MLContextTest - DataFrame sum DML, mllib vector with ID column");
+ LOG.debug("MLContextTest - DataFrame sum DML, mllib vector with ID column");
List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list = new ArrayList<>();
list.add(new Tuple2<>(1.0, org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
@@ -445,13 +450,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
@Test
public void testDataFrameSumDMLVectorWithNoIDColumn() {
- System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column");
+ LOG.debug("MLContextTest - DataFrame sum DML, vector with no ID column");
List<Vector> list = new ArrayList<>();
list.add(Vectors.dense(1.0, 2.0, 3.0));
@@ -468,13 +473,13 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
@Test
public void testDataFrameSumDMLMllibVectorWithNoIDColumn() {
- System.out.println("MLContextTest - DataFrame sum DML, mllib vector with no ID column");
+ LOG.debug("MLContextTest - DataFrame sum DML, mllib vector with no ID column");
List<org.apache.spark.mllib.linalg.Vector> list = new ArrayList<>();
list.add(org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0));
@@ -491,8 +496,8 @@
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
static class DoubleVectorRow implements Function<Tuple2<Double, Vector>, Row> {
@@ -552,7 +557,7 @@
public Row call(String str) throws Exception {
String[] strings = str.split(",");
Double[] doubles = new Double[strings.length];
- for (int i = 0; i < strings.length; i++) {
+ for(int i = 0; i < strings.length; i++) {
doubles[i] = Double.parseDouble(strings[i]);
}
return RowFactory.create((Object[]) doubles);
@@ -561,47 +566,48 @@
@Test
public void testCSVMatrixFileInputParameterSumDML() {
- System.out.println("MLContextTest - CSV matrix file input parameter sum DML");
+ LOG.debug("MLContextTest - CSV matrix file input parameter sum DML");
String s = "M = read($Min); print('sum: ' + sum(M));";
String csvFile = baseDirectory + File.separator + "1234.csv";
- setExpectedStdOut("sum: 10.0");
- ml.execute(dml(s).in("$Min", csvFile));
+ String out = executeAndCaptureStdOut(ml, dml(s).in("$Min", csvFile)).getRight();
+ assertTrue(out.contains("sum: 10.0"));
+
}
@Test
public void testCSVMatrixFileInputVariableSumDML() {
- System.out.println("MLContextTest - CSV matrix file input variable sum DML");
+ LOG.debug("MLContextTest - CSV matrix file input variable sum DML");
String s = "M = read($Min); print('sum: ' + sum(M));";
String csvFile = baseDirectory + File.separator + "1234.csv";
- setExpectedStdOut("sum: 10.0");
- ml.execute(dml(s).in("$Min", csvFile));
+ String out = executeAndCaptureStdOut(ml, dml(s).in("$Min", csvFile)).getRight();
+ assertTrue(out.contains("sum: 10.0"));
}
@Test
public void test2DDoubleSumDML() {
- System.out.println("MLContextTest - two-dimensional double array sum DML");
+ LOG.debug("MLContextTest - two-dimensional double array sum DML");
- double[][] matrix = new double[][] { { 10.0, 20.0 }, { 30.0, 40.0 } };
+ double[][] matrix = new double[][] {{10.0, 20.0}, {30.0, 40.0}};
Script script = dml("print('sum: ' + sum(M));").in("M", matrix);
- setExpectedStdOut("sum: 100.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 100.0"));
}
@Test
public void testAddScalarIntegerInputsDML() {
- System.out.println("MLContextTest - add scalar integer inputs DML");
+ LOG.debug("MLContextTest - add scalar integer inputs DML");
String s = "total = in1 + in2; print('total: ' + total);";
Script script = dml(s).in("in1", 1).in("in2", 2);
- setExpectedStdOut("total: 3");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("total: 3"));
}
@Test
public void testInputScalaMapDML() {
- System.out.println("MLContextTest - input Scala map DML");
+ LOG.debug("MLContextTest - input Scala map DML");
List<String> list = new ArrayList<>();
list.add("10,20");
@@ -620,15 +626,15 @@
String s = "M = M + $X; print('sum: ' + sum(M));";
Script script = dml(s).in(scalaMap);
- setExpectedStdOut("sum: 108.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 108.0"));
}
@Test
public void testOutputDoubleArrayMatrixDML() {
- System.out.println("MLContextTest - output double array matrix DML");
+ LOG.debug("MLContextTest - output double array matrix DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
- double[][] matrix = ml.execute(dml(s).out("M")).getMatrixAs2DDoubleArray("M");
+ double[][] matrix = executeAndCaptureStdOut(dml(s).out("M")).getLeft().getMatrixAs2DDoubleArray("M");
Assert.assertEquals(1.0, matrix[0][0], 0);
Assert.assertEquals(2.0, matrix[0][1], 0);
Assert.assertEquals(3.0, matrix[1][0], 0);
@@ -637,54 +643,54 @@
@Test
public void testOutputScalarLongDML() {
- System.out.println("MLContextTest - output scalar long DML");
+ LOG.debug("MLContextTest - output scalar long DML");
String s = "m = 5;";
- long result = ml.execute(dml(s).out("m")).getLong("m");
+ long result = executeAndCaptureStdOut(dml(s).out("m")).getLeft().getLong("m");
Assert.assertEquals(5, result);
}
@Test
public void testOutputScalarDoubleDML() {
- System.out.println("MLContextTest - output scalar double DML");
+ LOG.debug("MLContextTest - output scalar double DML");
String s = "m = 1.23";
- double result = ml.execute(dml(s).out("m")).getDouble("m");
+ double result = executeAndCaptureStdOut(dml(s).out("m")).getLeft().getDouble("m");
Assert.assertEquals(1.23, result, 0);
}
@Test
public void testOutputScalarBooleanDML() {
- System.out.println("MLContextTest - output scalar boolean DML");
+ LOG.debug("MLContextTest - output scalar boolean DML");
String s = "m = FALSE;";
- boolean result = ml.execute(dml(s).out("m")).getBoolean("m");
+ boolean result = executeAndCaptureStdOut(dml(s).out("m")).getLeft().getBoolean("m");
Assert.assertEquals(false, result);
}
@Test
public void testOutputScalarStringDML() {
- System.out.println("MLContextTest - output scalar string DML");
+ LOG.debug("MLContextTest - output scalar string DML");
String s = "m = 'hello';";
- String result = ml.execute(dml(s).out("m")).getString("m");
+ String result = executeAndCaptureStdOut(dml(s).out("m")).getLeft().getString("m");
Assert.assertEquals("hello", result);
}
@Test
public void testInputFrameDML() {
- System.out.println("MLContextTest - input frame DML");
+ LOG.debug("MLContextTest - input frame DML");
String s = "M = read($Min, data_type='frame', format='csv'); print(toString(M));";
String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
Script script = dml(s).in("$Min", csvFile);
- setExpectedStdOut("one");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("one"));
}
@Test
public void testOutputJavaRDDStringIJVDML() {
- System.out.println("MLContextTest - output Java RDD String IJV DML");
+ LOG.debug("MLContextTest - output Java RDD String IJV DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("M");
List<String> lines = javaRDDStringIJV.collect();
Assert.assertEquals("1 1 1.0", lines.get(0));
@@ -695,11 +701,11 @@
@Test
public void testOutputJavaRDDStringCSVDenseDML() {
- System.out.println("MLContextTest - output Java RDD String CSV Dense DML");
+ LOG.debug("MLContextTest - output Java RDD String CSV Dense DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("M");
List<String> lines = javaRDDStringCSV.collect();
Assert.assertEquals("1.0,2.0", lines.get(0));
@@ -707,16 +713,16 @@
}
/**
- * Reading from dense and sparse matrices is handled differently, so we have
- * tests for both dense and sparse matrices.
+ * Reading from dense and sparse matrices is handled differently, so we have tests for both dense and sparse
+ * matrices.
*/
@Test
public void testOutputJavaRDDStringCSVSparseDML() {
- System.out.println("MLContextTest - output Java RDD String CSV Sparse DML");
+ LOG.debug("MLContextTest - output Java RDD String CSV Sparse DML");
String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("M");
List<String> lines = javaRDDStringCSV.collect();
Assert.assertEquals("1.0,2.0", lines.get(0));
@@ -725,11 +731,11 @@
@Test
public void testOutputRDDStringIJVDML() {
- System.out.println("MLContextTest - output RDD String IJV DML");
+ LOG.debug("MLContextTest - output RDD String IJV DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
RDD<String> rddStringIJV = results.getRDDStringIJV("M");
Iterator<String> iterator = rddStringIJV.toLocalIterator();
Assert.assertEquals("1 1 1.0", iterator.next());
@@ -740,11 +746,11 @@
@Test
public void testOutputRDDStringCSVDenseDML() {
- System.out.println("MLContextTest - output RDD String CSV Dense DML");
+ LOG.debug("MLContextTest - output RDD String CSV Dense DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
RDD<String> rddStringCSV = results.getRDDStringCSV("M");
Iterator<String> iterator = rddStringCSV.toLocalIterator();
Assert.assertEquals("1.0,2.0", iterator.next());
@@ -753,11 +759,11 @@
@Test
public void testOutputRDDStringCSVSparseDML() {
- System.out.println("MLContextTest - output RDD String CSV Sparse DML");
+ LOG.debug("MLContextTest - output RDD String CSV Sparse DML");
String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
RDD<String> rddStringCSV = results.getRDDStringCSV("M");
Iterator<String> iterator = rddStringCSV.toLocalIterator();
Assert.assertEquals("1.0,2.0", iterator.next());
@@ -766,11 +772,11 @@
@Test
public void testOutputDataFrameDML() {
- System.out.println("MLContextTest - output DataFrame DML");
+ LOG.debug("MLContextTest - output DataFrame DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Dataset<Row> dataFrame = results.getDataFrame("M");
List<Row> list = dataFrame.collectAsList();
Row row1 = list.get(0);
@@ -786,47 +792,47 @@
@Test
public void testOutputDataFrameDMLVectorWithIDColumn() {
- System.out.println("MLContextTest - output DataFrame DML, vector with ID column");
+ LOG.debug("MLContextTest - output DataFrame DML, vector with ID column");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Dataset<Row> dataFrame = results.getDataFrameVectorWithIDColumn("M");
List<Row> list = dataFrame.collectAsList();
Row row1 = list.get(0);
Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
- Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(1)).toArray(), 0.0);
+ Assert.assertArrayEquals(new double[] {1.0, 2.0}, ((Vector) row1.get(1)).toArray(), 0.0);
Row row2 = list.get(1);
Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
- Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(1)).toArray(), 0.0);
+ Assert.assertArrayEquals(new double[] {3.0, 4.0}, ((Vector) row2.get(1)).toArray(), 0.0);
}
@Test
public void testOutputDataFrameDMLVectorNoIDColumn() {
- System.out.println("MLContextTest - output DataFrame DML, vector no ID column");
+ LOG.debug("MLContextTest - output DataFrame DML, vector no ID column");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Dataset<Row> dataFrame = results.getDataFrameVectorNoIDColumn("M");
List<Row> list = dataFrame.collectAsList();
Row row1 = list.get(0);
- Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(0)).toArray(), 0.0);
+ Assert.assertArrayEquals(new double[] {1.0, 2.0}, ((Vector) row1.get(0)).toArray(), 0.0);
Row row2 = list.get(1);
- Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(0)).toArray(), 0.0);
+ Assert.assertArrayEquals(new double[] {3.0, 4.0}, ((Vector) row2.get(0)).toArray(), 0.0);
}
@Test
public void testOutputDataFrameDMLDoublesWithIDColumn() {
- System.out.println("MLContextTest - output DataFrame DML, doubles with ID column");
+ LOG.debug("MLContextTest - output DataFrame DML, doubles with ID column");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Dataset<Row> dataFrame = results.getDataFrameDoubleWithIDColumn("M");
List<Row> list = dataFrame.collectAsList();
@@ -843,11 +849,11 @@
@Test
public void testOutputDataFrameDMLDoublesNoIDColumn() {
- System.out.println("MLContextTest - output DataFrame DML, doubles no ID column");
+ LOG.debug("MLContextTest - output DataFrame DML, doubles no ID column");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Dataset<Row> dataFrame = results.getDataFrameDoubleNoIDColumn("M");
List<Row> list = dataFrame.collectAsList();
@@ -862,55 +868,55 @@
@Test
public void testTwoScriptsDML() {
- System.out.println("MLContextTest - two scripts with inputs and outputs DML");
+ LOG.debug("MLContextTest - two scripts with inputs and outputs DML");
- double[][] m1 = new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } };
+ double[][] m1 = new double[][] {{1.0, 2.0}, {3.0, 4.0}};
String s1 = "sum1 = sum(m1);";
- double sum1 = ml.execute(dml(s1).in("m1", m1).out("sum1")).getDouble("sum1");
+ double sum1 = executeAndCaptureStdOut(dml(s1).in("m1", m1).out("sum1")).getLeft().getDouble("sum1");
Assert.assertEquals(10.0, sum1, 0.0);
- double[][] m2 = new double[][] { { 5.0, 6.0 }, { 7.0, 8.0 } };
+ double[][] m2 = new double[][] {{5.0, 6.0}, {7.0, 8.0}};
String s2 = "sum2 = sum(m2);";
- double sum2 = ml.execute(dml(s2).in("m2", m2).out("sum2")).getDouble("sum2");
+ double sum2 = executeAndCaptureStdOut(dml(s2).in("m2", m2).out("sum2")).getLeft().getDouble("sum2");
Assert.assertEquals(26.0, sum2, 0.0);
}
@Test
public void testOneScriptTwoExecutionsDML() {
- System.out.println("MLContextTest - one script with two executions DML");
+ LOG.debug("MLContextTest - one script with two executions DML");
Script script = new Script();
- double[][] m1 = new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } };
+ double[][] m1 = new double[][] {{1.0, 2.0}, {3.0, 4.0}};
script.setScriptString("sum1 = sum(m1);").in("m1", m1).out("sum1");
- ml.execute(script);
+ executeAndCaptureStdOut(script);
Assert.assertEquals(10.0, script.results().getDouble("sum1"), 0.0);
script.clearAll();
- double[][] m2 = new double[][] { { 5.0, 6.0 }, { 7.0, 8.0 } };
+ double[][] m2 = new double[][] {{5.0, 6.0}, {7.0, 8.0}};
script.setScriptString("sum2 = sum(m2);").in("m2", m2).out("sum2");
- ml.execute(script);
+ executeAndCaptureStdOut(script);
Assert.assertEquals(26.0, script.results().getDouble("sum2"), 0.0);
}
@Test
public void testInputParameterBooleanDML() {
- System.out.println("MLContextTest - input parameter boolean DML");
+ LOG.debug("MLContextTest - input parameter boolean DML");
String s = "x = $X; if (x == TRUE) { print('yes'); }";
Script script = dml(s).in("$X", true);
- setExpectedStdOut("yes");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("yes"));
}
@Test
public void testMultipleOutDML() {
- System.out.println("MLContextTest - multiple out DML");
+ LOG.debug("MLContextTest - multiple out DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2); N = sum(M)";
// alternative to .out("M").out("N")
- MLResults results = ml.execute(dml(s).out("M", "N"));
+ MLResults results = executeAndCaptureStdOut(dml(s).out("M", "N")).getLeft();
double[][] matrix = results.getMatrixAs2DDoubleArray("M");
double sum = results.getDouble("N");
Assert.assertEquals(1.0, matrix[0][0], 0);
@@ -922,9 +928,9 @@
@Test
public void testOutputMatrixObjectDML() {
- System.out.println("MLContextTest - output matrix object DML");
+ LOG.debug("MLContextTest - output matrix object DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
- MatrixObject mo = ml.execute(dml(s).out("M")).getMatrixObject("M");
+ MatrixObject mo = executeAndCaptureStdOut(dml(s).out("M")).getLeft().getMatrixObject("M");
RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo);
Iterator<String> iterator = rddStringCSV.toLocalIterator();
Assert.assertEquals("1.0,2.0", iterator.next());
@@ -933,7 +939,7 @@
@Test
public void testInputMatrixBlockDML() {
- System.out.println("MLContextTest - input MatrixBlock DML");
+ LOG.debug("MLContextTest - input MatrixBlock DML");
List<String> list = new ArrayList<>();
list.add("10,20,30");
@@ -952,15 +958,15 @@
Matrix m = new Matrix(dataFrame);
MatrixBlock matrixBlock = m.toMatrixBlock();
Script script = dml("avg = avg(M);").in("M", matrixBlock).out("avg");
- double avg = ml.execute(script).getDouble("avg");
+ double avg = executeAndCaptureStdOut(script).getLeft().getDouble("avg");
Assert.assertEquals(50.0, avg, 0.0);
}
@Test
public void testOutputBinaryBlocksDML() {
- System.out.println("MLContextTest - output binary blocks DML");
+ LOG.debug("MLContextTest - output binary blocks DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
- MLResults results = ml.execute(dml(s).out("M"));
+ MLResults results = executeAndCaptureStdOut(dml(s).out("M")).getLeft();
Matrix m = results.getMatrix("M");
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = m.toBinaryBlocks();
MatrixMetadata mm = m.getMatrixMetadata();
@@ -976,11 +982,11 @@
@Test
public void testOutputListStringCSVDenseDML() {
- System.out.println("MLContextTest - output List String CSV Dense DML");
+ LOG.debug("MLContextTest - output List String CSV Dense DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
MatrixObject mo = results.getMatrixObject("M");
List<String> lines = MLContextConversionUtil.matrixObjectToListStringCSV(mo);
Assert.assertEquals("1.0,2.0", lines.get(0));
@@ -989,11 +995,11 @@
@Test
public void testOutputListStringCSVSparseDML() {
- System.out.println("MLContextTest - output List String CSV Sparse DML");
+ LOG.debug("MLContextTest - output List String CSV Sparse DML");
String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
MatrixObject mo = results.getMatrixObject("M");
List<String> lines = MLContextConversionUtil.matrixObjectToListStringCSV(mo);
Assert.assertEquals("1.0,2.0", lines.get(0));
@@ -1002,11 +1008,11 @@
@Test
public void testOutputListStringIJVDenseDML() {
- System.out.println("MLContextTest - output List String IJV Dense DML");
+ LOG.debug("MLContextTest - output List String IJV Dense DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
MatrixObject mo = results.getMatrixObject("M");
List<String> lines = MLContextConversionUtil.matrixObjectToListStringIJV(mo);
Assert.assertEquals("1 1 1.0", lines.get(0));
@@ -1017,11 +1023,11 @@
@Test
public void testOutputListStringIJVSparseDML() {
- System.out.println("MLContextTest - output List String IJV Sparse DML");
+ LOG.debug("MLContextTest - output List String IJV Sparse DML");
String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
MatrixObject mo = results.getMatrixObject("M");
List<String> lines = MLContextConversionUtil.matrixObjectToListStringIJV(mo);
Assert.assertEquals("1 1 1.0", lines.get(0));
@@ -1032,7 +1038,7 @@
@Test
public void testJavaRDDGoodMetadataDML() {
- System.out.println("MLContextTest - JavaRDD<String> good metadata DML");
+ LOG.debug("MLContextTest - JavaRDD<String> good metadata DML");
List<String> list = new ArrayList<>();
list.add("1,2,3");
@@ -1043,13 +1049,13 @@
MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD, mm);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
- @Test(expected = MLContextException.class)
+ @Test
public void testJavaRDDBadMetadataDML() {
- System.out.println("MLContextTest - JavaRDD<String> bad metadata DML");
+ LOG.debug("MLContextTest - JavaRDD<String> bad metadata DML");
List<String> list = new ArrayList<>();
list.add("1,2,3");
@@ -1060,12 +1066,12 @@
MatrixMetadata mm = new MatrixMetadata(1, 1, 9);
Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD, mm);
- ml.execute(script);
+ executeAndCaptureStdOut(script, MLContextException.class);
}
@Test
public void testRDDGoodMetadataDML() {
- System.out.println("MLContextTest - RDD<String> good metadata DML");
+ LOG.debug("MLContextTest - RDD<String> good metadata DML");
List<String> list = new ArrayList<>();
list.add("1,1,1");
@@ -1077,13 +1083,13 @@
MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
Script script = dml("print('sum: ' + sum(M));").in("M", rdd, mm);
- setExpectedStdOut("sum: 18.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 18.0"));
}
@Test
public void testDataFrameGoodMetadataDML() {
- System.out.println("MLContextTest - DataFrame good metadata DML");
+ LOG.debug("MLContextTest - DataFrame good metadata DML");
List<String> list = new ArrayList<>();
list.add("10,20,30");
@@ -1102,14 +1108,14 @@
MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
- setExpectedStdOut("sum: 450.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 450.0"));
}
- @SuppressWarnings({ "rawtypes", "unchecked" })
+ @SuppressWarnings({"rawtypes", "unchecked"})
@Test
public void testInputTupleSeqNoMetadataDML() {
- System.out.println("MLContextTest - Tuple sequence no metadata DML");
+ LOG.debug("MLContextTest - Tuple sequence no metadata DML");
List<String> list1 = new ArrayList<>();
list1.add("1,2");
@@ -1131,14 +1137,15 @@
Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
Script script = dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
- setExpectedStdOut("sums: 10.0 26.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sums: 10.0 26.0"));
+ executeAndCaptureStdOut(script);
}
- @SuppressWarnings({ "rawtypes", "unchecked" })
+ @SuppressWarnings({"rawtypes", "unchecked"})
@Test
public void testInputTupleSeqWithMetadataDML() {
- System.out.println("MLContextTest - Tuple sequence with metadata DML");
+ LOG.debug("MLContextTest - Tuple sequence with metadata DML");
List<String> list1 = new ArrayList<>();
list1.add("1,2");
@@ -1163,34 +1170,34 @@
Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
Script script = dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
- setExpectedStdOut("sums: 10.0 26.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sums: 10.0 26.0"));
}
@Test
public void testCSVMatrixFromURLSumDML() throws MalformedURLException {
- System.out.println("MLContextTest - CSV matrix from URL sum DML");
+ LOG.debug("MLContextTest - CSV matrix from URL sum DML");
String csv = "https://raw.githubusercontent.com/apache/systemml/master/src/test/scripts/functions/mlcontext/1234.csv";
URL url = new URL(csv);
Script script = dml("print('sum: ' + sum(M));").in("M", url);
- setExpectedStdOut("sum: 10.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 10.0"));
}
@Test
public void testIJVMatrixFromURLSumDML() throws MalformedURLException {
- System.out.println("MLContextTest - IJV matrix from URL sum DML");
+ LOG.debug("MLContextTest - IJV matrix from URL sum DML");
String ijv = "https://raw.githubusercontent.com/apache/systemml/master/src/test/scripts/functions/mlcontext/1234.ijv";
URL url = new URL(ijv);
MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, 2, 2);
Script script = dml("print('sum: ' + sum(M));").in("M", url, mm);
- setExpectedStdOut("sum: 10.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 10.0"));
}
@Test
public void testDataFrameSumDMLDoublesWithNoIDColumnNoFormatSpecified() {
- System.out.println("MLContextTest - DataFrame sum DML, doubles with no ID column, no format specified");
+ LOG.debug("MLContextTest - DataFrame sum DML, doubles with no ID column, no format specified");
List<String> list = new ArrayList<>();
list.add("2,2,2");
@@ -1207,13 +1214,13 @@
Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame);
- setExpectedStdOut("sum: 27.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 27.0"));
}
@Test
public void testDataFrameSumDMLDoublesWithIDColumnNoFormatSpecified() {
- System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column, no format specified");
+ LOG.debug("MLContextTest - DataFrame sum DML, doubles with ID column, no format specified");
List<String> list = new ArrayList<>();
list.add("1,2,2,2");
@@ -1231,13 +1238,13 @@
Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame);
- setExpectedStdOut("sum: 27.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 27.0"));
}
@Test
public void testDataFrameSumDMLVectorWithIDColumnNoFormatSpecified() {
- System.out.println("MLContextTest - DataFrame sum DML, vector with ID column, no format specified");
+ LOG.debug("MLContextTest - DataFrame sum DML, vector with ID column, no format specified");
List<Tuple2<Double, Vector>> list = new ArrayList<>();
list.add(new Tuple2<>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
@@ -1253,13 +1260,13 @@
Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
@Test
public void testDataFrameSumDMLVectorWithNoIDColumnNoFormatSpecified() {
- System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column, no format specified");
+ LOG.debug("MLContextTest - DataFrame sum DML, vector with no ID column, no format specified");
List<Vector> list = new ArrayList<>();
list.add(Vectors.dense(1.0, 2.0, 3.0));
@@ -1274,225 +1281,225 @@
Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);
Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame);
- setExpectedStdOut("sum: 45.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("sum: 45.0"));
}
@Test
public void testDisplayBooleanDML() {
- System.out.println("MLContextTest - display boolean DML");
+ LOG.debug("MLContextTest - display boolean DML");
String s = "print(b);";
Script script = dml(s).in("b", true);
- setExpectedStdOut("TRUE");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("TRUE"));
}
@Test
public void testDisplayBooleanNotDML() {
- System.out.println("MLContextTest - display boolean 'not' DML");
+ LOG.debug("MLContextTest - display boolean 'not' DML");
String s = "print(!b);";
Script script = dml(s).in("b", true);
- setExpectedStdOut("FALSE");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("FALSE"));
}
@Test
public void testDisplayIntegerAddDML() {
- System.out.println("MLContextTest - display integer add DML");
+ LOG.debug("MLContextTest - display integer add DML");
String s = "print(i+j);";
Script script = dml(s).in("i", 5).in("j", 6);
- setExpectedStdOut("11");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("11"));
}
@Test
public void testDisplayStringConcatenationDML() {
- System.out.println("MLContextTest - display string concatenation DML");
+ LOG.debug("MLContextTest - display string concatenation DML");
String s = "print(str1+str2);";
Script script = dml(s).in("str1", "hello").in("str2", "goodbye");
- setExpectedStdOut("hellogoodbye");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hellogoodbye"));
}
@Test
public void testDisplayDoubleAddDML() {
- System.out.println("MLContextTest - display double add DML");
+ LOG.debug("MLContextTest - display double add DML");
String s = "print(i+j);";
Script script = dml(s).in("i", 5.1).in("j", 6.2);
- setExpectedStdOut("11.3");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("11.3"));
}
@Test
public void testPrintFormattingStringSubstitution() {
- System.out.println("MLContextTest - print formatting string substitution");
+ LOG.debug("MLContextTest - print formatting string substitution");
Script script = dml("print('hello %s', 'world');");
- setExpectedStdOut("hello world");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hello world"));
}
@Test
public void testPrintFormattingStringSubstitutions() {
- System.out.println("MLContextTest - print formatting string substitutions");
+ LOG.debug("MLContextTest - print formatting string substitutions");
Script script = dml("print('%s %s', 'hello', 'world');");
- setExpectedStdOut("hello world");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hello world"));
}
@Test
public void testPrintFormattingStringSubstitutionAlignment() {
- System.out.println("MLContextTest - print formatting string substitution alignment");
+ LOG.debug("MLContextTest - print formatting string substitution alignment");
Script script = dml("print(\"'%10s' '%-10s'\", \"hello\", \"world\");");
- setExpectedStdOut("' hello' 'world '");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("' hello' 'world '"));
}
@Test
public void testPrintFormattingStringSubstitutionVariables() {
- System.out.println("MLContextTest - print formatting string substitution variables");
+ LOG.debug("MLContextTest - print formatting string substitution variables");
Script script = dml("a='hello'; b='world'; print('%s %s', a, b);");
- setExpectedStdOut("hello world");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hello world"));
}
@Test
public void testPrintFormattingIntegerSubstitution() {
- System.out.println("MLContextTest - print formatting integer substitution");
+ LOG.debug("MLContextTest - print formatting integer substitution");
Script script = dml("print('int %d', 42);");
- setExpectedStdOut("int 42");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("int 42"));
}
@Test
public void testPrintFormattingIntegerSubstitutions() {
- System.out.println("MLContextTest - print formatting integer substitutions");
+ LOG.debug("MLContextTest - print formatting integer substitutions");
Script script = dml("print('%d %d', 42, 43);");
- setExpectedStdOut("42 43");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("42 43"));
}
@Test
public void testPrintFormattingIntegerSubstitutionAlignment() {
- System.out.println("MLContextTest - print formatting integer substitution alignment");
+ LOG.debug("MLContextTest - print formatting integer substitution alignment");
Script script = dml("print(\"'%10d' '%-10d'\", 42, 43);");
- setExpectedStdOut("' 42' '43 '");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("' 42' '43 '"));
}
@Test
public void testPrintFormattingIntegerSubstitutionVariables() {
- System.out.println("MLContextTest - print formatting integer substitution variables");
+ LOG.debug("MLContextTest - print formatting integer substitution variables");
Script script = dml("a=42; b=43; print('%d %d', a, b);");
- setExpectedStdOut("42 43");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("42 43"));
}
@Test
public void testPrintFormattingDoubleSubstitution() {
- System.out.println("MLContextTest - print formatting double substitution");
+ LOG.debug("MLContextTest - print formatting double substitution");
Script script = dml("print('double %f', 42.0);");
- setExpectedStdOut("double 42.000000");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("double 42.000000"));
}
@Test
public void testPrintFormattingDoubleSubstitutions() {
- System.out.println("MLContextTest - print formatting double substitutions");
+ LOG.debug("MLContextTest - print formatting double substitutions");
Script script = dml("print('%f %f', 42.42, 43.43);");
- setExpectedStdOut("42.420000 43.430000");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("42.420000 43.430000"));
}
@Test
public void testPrintFormattingDoubleSubstitutionAlignment() {
- System.out.println("MLContextTest - print formatting double substitution alignment");
+ LOG.debug("MLContextTest - print formatting double substitution alignment");
Script script = dml("print(\"'%10.2f' '%-10.2f'\", 42.53, 43.54);");
- setExpectedStdOut("' 42.53' '43.54 '");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("' 42.53' '43.54 '"));
}
@Test
public void testPrintFormattingDoubleSubstitutionVariables() {
- System.out.println("MLContextTest - print formatting double substitution variables");
+ LOG.debug("MLContextTest - print formatting double substitution variables");
Script script = dml("a=12.34; b=56.78; print('%f %f', a, b);");
- setExpectedStdOut("12.340000 56.780000");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("12.340000 56.780000"));
}
@Test
public void testPrintFormattingBooleanSubstitution() {
- System.out.println("MLContextTest - print formatting boolean substitution");
+ LOG.debug("MLContextTest - print formatting boolean substitution");
Script script = dml("print('boolean %b', TRUE);");
- setExpectedStdOut("boolean true");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("boolean true"));
}
@Test
public void testPrintFormattingBooleanSubstitutions() {
- System.out.println("MLContextTest - print formatting boolean substitutions");
+ LOG.debug("MLContextTest - print formatting boolean substitutions");
Script script = dml("print('%b %b', TRUE, FALSE);");
- setExpectedStdOut("true false");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("true false"));
}
@Test
public void testPrintFormattingBooleanSubstitutionAlignment() {
- System.out.println("MLContextTest - print formatting boolean substitution alignment");
+ LOG.debug("MLContextTest - print formatting boolean substitution alignment");
Script script = dml("print(\"'%10b' '%-10b'\", TRUE, FALSE);");
- setExpectedStdOut("' true' 'false '");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("' true' 'false '"));
}
@Test
public void testPrintFormattingBooleanSubstitutionVariables() {
- System.out.println("MLContextTest - print formatting boolean substitution variables");
+ LOG.debug("MLContextTest - print formatting boolean substitution variables");
Script script = dml("a=TRUE; b=FALSE; print('%b %b', a, b);");
- setExpectedStdOut("true false");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("true false"));
}
@Test
public void testPrintFormattingMultipleTypes() {
- System.out.println("MLContextTest - print formatting multiple types");
+ LOG.debug("MLContextTest - print formatting multiple types");
Script script = dml("a='hello'; b=3; c=4.5; d=TRUE; print('%s %d %f %b', a, b, c, d);");
- setExpectedStdOut("hello 3 4.500000 true");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hello 3 4.500000 true"));
}
@Test
public void testPrintFormattingMultipleExpressions() {
- System.out.println("MLContextTest - print formatting multiple expressions");
+ LOG.debug("MLContextTest - print formatting multiple expressions");
Script script = dml(
- "a='hello'; b='goodbye'; c=4; d=3; e=3.0; f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);");
- setExpectedStdOut("hellogoodbye 1 15.000000 true");
- ml.execute(script);
+ "a='hello'; b='goodbye'; c=4; d=3; e=3.0; f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);");
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("hellogoodbye 1 15.000000 true"));
}
@Test
public void testPrintFormattingForLoop() {
- System.out.println("MLContextTest - print formatting for loop");
+ LOG.debug("MLContextTest - print formatting for loop");
Script script = dml("for (i in 1:3) { print('int value %d', i); }");
// check that one of the lines is returned
- setExpectedStdOut("int value 3");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("int value 3"));
}
@Test
public void testPrintFormattingParforLoop() {
- System.out.println("MLContextTest - print formatting parfor loop");
+ LOG.debug("MLContextTest - print formatting parfor loop");
Script script = dml("parfor (i in 1:3) { print('int value %d', i); }");
// check that one of the lines is returned
- setExpectedStdOut("int value 3");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("int value 3"));
}
@Test
public void testPrintFormattingForLoopMultiply() {
- System.out.println("MLContextTest - print formatting for loop multiply");
+ LOG.debug("MLContextTest - print formatting for loop multiply");
Script script = dml("a = 5.0; for (i in 1:3) { print('%d %f', i, a * i); }");
// check that one of the lines is returned
- setExpectedStdOut("3 15.000000");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("3 15.000000"));
}
@Test
@@ -1517,95 +1524,95 @@
@Test
public void testInputVariablesAddLongsDML() {
- System.out.println("MLContextTest - input variables add longs DML");
+ LOG.debug("MLContextTest - input variables add longs DML");
String s = "print('x + y = ' + (x + y));";
Script script = dml(s).in("x", 3L).in("y", 4L);
- setExpectedStdOut("x + y = 7");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("x + y = 7"));
}
@Test
public void testInputVariablesAddFloatsDML() {
- System.out.println("MLContextTest - input variables add floats DML");
+ LOG.debug("MLContextTest - input variables add floats DML");
String s = "print('x + y = ' + (x + y));";
Script script = dml(s).in("x", 3F).in("y", 4F);
- setExpectedStdOut("x + y = 7.0");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("x + y = 7.0"));
}
@Test
public void testFunctionNoReturnValueDML() {
- System.out.println("MLContextTest - function with no return value DML");
+ LOG.debug("MLContextTest - function with no return value DML");
String s = "hello=function(){print('no return value')}\nhello();";
Script script = dml(s);
- setExpectedStdOut("no return value");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("no return value"));
}
@Test
public void testFunctionNoReturnValueForceFunctionCallDML() {
- System.out.println("MLContextTest - function with no return value, force function call DML");
+ LOG.debug("MLContextTest - function with no return value, force function call DML");
String s = "hello=function(){\nwhile(FALSE){};\nprint('no return value, force function call');\n}\nhello();";
Script script = dml(s);
- setExpectedStdOut("no return value, force function call");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("no return value, force function call"));
}
@Test
public void testFunctionReturnValueDML() {
- System.out.println("MLContextTest - function with return value DML");
+ LOG.debug("MLContextTest - function with return value DML");
String s = "hello=function()return(string s){s='return value'}\na=hello();\nprint(a);";
Script script = dml(s);
- setExpectedStdOut("return value");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("return value"));
}
@Test
public void testFunctionTwoReturnValuesDML() {
- System.out.println("MLContextTest - function with two return values DML");
+ LOG.debug("MLContextTest - function with two return values DML");
String s = "hello=function()return(string s1,string s2){s1='return'; s2='values'}\n[a,b]=hello();\nprint(a+' '+b);";
Script script = dml(s);
- setExpectedStdOut("return values");
- ml.execute(script);
+ String out = executeAndCaptureStdOut(ml, script).getRight();
+ assertTrue(out.contains("return values"));
}
@Test
public void testOutputListDML() {
- System.out.println("MLContextTest - output specified as List DML");
+ LOG.debug("MLContextTest - output specified as List DML");
List<String> outputs = Arrays.asList("x", "y");
Script script = dml("a=1;x=a+1;y=x+1").out(outputs);
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Assert.assertEquals(2, results.getLong("x"));
Assert.assertEquals(3, results.getLong("y"));
}
- @SuppressWarnings({ "unchecked", "rawtypes" })
+ @SuppressWarnings({"unchecked", "rawtypes"})
@Test
public void testOutputScalaSeqDML() {
- System.out.println("MLContextTest - output specified as Scala Seq DML");
+ LOG.debug("MLContextTest - output specified as Scala Seq DML");
List outputs = Arrays.asList("x", "y");
Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq();
Script script = dml("a=1;x=a+1;y=x+1").out(seq);
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Assert.assertEquals(2, results.getLong("x"));
Assert.assertEquals(3, results.getLong("y"));
}
@Test
public void testOutputDataFrameOfVectorsDML() {
- System.out.println("MLContextTest - output DataFrame of vectors DML");
+ LOG.debug("MLContextTest - output DataFrame of vectors DML");
String s = "m=matrix('1 2 3 4',rows=2,cols=2);";
Script script = dml(s).out("m");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
Dataset<Row> df = results.getDataFrame("m", true);
Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
@@ -1623,21 +1630,21 @@
Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
Vector v1 = (DenseVector) row1.get(1);
double[] arr1 = v1.toArray();
- Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, arr1, 0.0);
+ Assert.assertArrayEquals(new double[] {1.0, 2.0}, arr1, 0.0);
Row row2 = list.get(1);
Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
Vector v2 = (DenseVector) row2.get(1);
double[] arr2 = v2.toArray();
- Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, arr2, 0.0);
+ Assert.assertArrayEquals(new double[] {3.0, 4.0}, arr2, 0.0);
}
@Test
public void testOutputDoubleArrayFromMatrixDML() {
- System.out.println("MLContextTest - output double array from matrix DML");
+ LOG.debug("MLContextTest - output double array from matrix DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
- double[][] matrix = ml.execute(dml(s).out("M")).getMatrix("M").to2DDoubleArray();
+ double[][] matrix = executeAndCaptureStdOut(dml(s).out("M")).getLeft().getMatrix("M").to2DDoubleArray();
Assert.assertEquals(1.0, matrix[0][0], 0);
Assert.assertEquals(2.0, matrix[0][1], 0);
Assert.assertEquals(3.0, matrix[1][0], 0);
@@ -1646,11 +1653,11 @@
@Test
public void testOutputDataFrameFromMatrixDML() {
- System.out.println("MLContextTest - output DataFrame from matrix DML");
+ LOG.debug("MLContextTest - output DataFrame from matrix DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- Dataset<Row> df = ml.execute(script).getMatrix("M").toDF();
+ Dataset<Row> df = executeAndCaptureStdOut(script).getLeft().getMatrix("M").toDF();
Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
List<Row> list = sortedDF.collectAsList();
Row row1 = list.get(0);
@@ -1666,11 +1673,11 @@
@Test
public void testOutputDataFrameDoublesNoIDColumnFromMatrixDML() {
- System.out.println("MLContextTest - output DataFrame of doubles with no ID column from matrix DML");
+ LOG.debug("MLContextTest - output DataFrame of doubles with no ID column from matrix DML");
String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
Script script = dml(s).out("M");
- Dataset<Row> df = ml.execute(script).getMatrix("M").toDFDoubleNoIDColumn();
+ Dataset<Row> df = executeAndCaptureStdOut(script).getLeft().getMatrix("M").toDFDoubleNoIDColumn();
List<Row> list = df.collectAsList();
Row row = list.get(0);
@@ -1682,11 +1689,11 @@
@Test
public void testOutputDataFrameDoublesWithIDColumnFromMatrixDML() {
- System.out.println("MLContextTest - output DataFrame of doubles with ID column from matrix DML");
+ LOG.debug("MLContextTest - output DataFrame of doubles with ID column from matrix DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- Dataset<Row> df = ml.execute(script).getMatrix("M").toDFDoubleWithIDColumn();
+ Dataset<Row> df = executeAndCaptureStdOut(script).getLeft().getMatrix("M").toDFDoubleWithIDColumn();
Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
List<Row> list = sortedDF.collectAsList();
@@ -1703,49 +1710,50 @@
@Test
public void testOutputDataFrameVectorsNoIDColumnFromMatrixDML() {
- System.out.println("MLContextTest - output DataFrame of vectors with no ID column from matrix DML");
+ LOG.debug("MLContextTest - output DataFrame of vectors with no ID column from matrix DML");
String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
Script script = dml(s).out("M");
- Dataset<Row> df = ml.execute(script).getMatrix("M").toDFVectorNoIDColumn();
+ Dataset<Row> df = executeAndCaptureStdOut(script).getLeft().getMatrix("M").toDFVectorNoIDColumn();
List<Row> list = df.collectAsList();
Row row = list.get(0);
- Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, ((Vector) row.get(0)).toArray(), 0.0);
+ Assert.assertArrayEquals(new double[] {1.0, 2.0, 3.0, 4.0}, ((Vector) row.get(0)).toArray(), 0.0);
}
@Test
public void testOutputDataFrameVectorsWithIDColumnFromMatrixDML() {
- System.out.println("MLContextTest - output DataFrame of vectors with ID column from matrix DML");
+ LOG.debug("MLContextTest - output DataFrame of vectors with ID column from matrix DML");
String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
Script script = dml(s).out("M");
- Dataset<Row> df = ml.execute(script).getMatrix("M").toDFVectorWithIDColumn();
+ Dataset<Row> df = executeAndCaptureStdOut(script).getLeft().getMatrix("M").toDFVectorWithIDColumn();
List<Row> list = df.collectAsList();
Row row = list.get(0);
Assert.assertEquals(1.0, row.getDouble(0), 0.0);
- Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, ((Vector) row.get(1)).toArray(), 0.0);
+ Assert.assertArrayEquals(new double[] {1.0, 2.0, 3.0, 4.0}, ((Vector) row.get(1)).toArray(), 0.0);
}
@Test
public void testOutputJavaRDDStringCSVFromMatrixDML() {
- System.out.println("MLContextTest - output Java RDD String CSV from matrix DML");
+ LOG.debug("MLContextTest - output Java RDD String CSV from matrix DML");
String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
Script script = dml(s).out("M");
- JavaRDD<String> javaRDDStringCSV = ml.execute(script).getMatrix("M").toJavaRDDStringCSV();
+ JavaRDD<String> javaRDDStringCSV = executeAndCaptureStdOut(script).getLeft().getMatrix("M")
+ .toJavaRDDStringCSV();
List<String> lines = javaRDDStringCSV.collect();
Assert.assertEquals("1.0,2.0,3.0,4.0", lines.get(0));
}
@Test
public void testOutputJavaRDDStringIJVFromMatrixDML() {
- System.out.println("MLContextTest - output Java RDD String IJV from matrix DML");
+ LOG.debug("MLContextTest - output Java RDD String IJV from matrix DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- MLResults results = ml.execute(script);
+ MLResults results = executeAndCaptureStdOut(script).getLeft();
JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("M");
List<String> lines = javaRDDStringIJV.sortBy(row -> row, true, 1).collect();
Assert.assertEquals("1 1 1.0", lines.get(0));
@@ -1756,22 +1764,22 @@
@Test
public void testOutputRDDStringCSVFromMatrixDML() {
- System.out.println("MLContextTest - output RDD String CSV from matrix DML");
+ LOG.debug("MLContextTest - output RDD String CSV from matrix DML");
String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
Script script = dml(s).out("M");
- RDD<String> rddStringCSV = ml.execute(script).getMatrix("M").toRDDStringCSV();
+ RDD<String> rddStringCSV = executeAndCaptureStdOut(script).getLeft().getMatrix("M").toRDDStringCSV();
Iterator<String> iterator = rddStringCSV.toLocalIterator();
Assert.assertEquals("1.0,2.0,3.0,4.0", iterator.next());
}
@Test
public void testOutputRDDStringIJVFromMatrixDML() {
- System.out.println("MLContextTest - output RDD String IJV from matrix DML");
+ LOG.debug("MLContextTest - output RDD String IJV from matrix DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
Script script = dml(s).out("M");
- RDD<String> rddStringIJV = ml.execute(script).getMatrix("M").toRDDStringIJV();
+ RDD<String> rddStringIJV = executeAndCaptureStdOut(script).getLeft().getMatrix("M").toRDDStringIJV();
String[] rows = (String[]) rddStringIJV.collect();
Arrays.sort(rows);
Assert.assertEquals("1 1 1.0", rows[0]);
@@ -1782,7 +1790,7 @@
@Test
public void testMLContextVersionMessage() {
- System.out.println("MLContextTest - version message");
+ LOG.debug("MLContextTest - version message");
String version = ml.version();
// not available until jar built
@@ -1791,7 +1799,7 @@
@Test
public void testMLContextBuildTimeMessage() {
- System.out.println("MLContextTest - build time message");
+ LOG.debug("MLContextTest - build time message");
String buildTime = ml.buildTime();
// not available until jar built
@@ -1802,12 +1810,12 @@
public void testMLContextCreateAndClose() {
// MLContext created by the @BeforeClass method in MLContextTestBase
// MLContext closed by the @AfterClass method in MLContextTestBase
- System.out.println("MLContextTest - create MLContext and close (without script execution)");
+ LOG.debug("MLContextTest - create MLContext and close (without script execution)");
}
@Test
public void testDataFrameToBinaryBlocks() {
- System.out.println("MLContextTest - DataFrame to binary blocks");
+ LOG.debug("MLContextTest - DataFrame to binary blocks");
List<String> list = new ArrayList<>();
list.add("1,2,3");
@@ -1824,20 +1832,20 @@
Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = MLContextConversionUtil
- .dataFrameToMatrixBinaryBlocks(dataFrame);
+ .dataFrameToMatrixBinaryBlocks(dataFrame);
Tuple2<MatrixIndexes, MatrixBlock> first = binaryBlocks.first();
MatrixBlock mb = first._2();
double[][] matrix = DataConverter.convertToDoubleMatrix(mb);
- Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0 }, matrix[0], 0.0);
- Assert.assertArrayEquals(new double[] { 4.0, 5.0, 6.0 }, matrix[1], 0.0);
- Assert.assertArrayEquals(new double[] { 7.0, 8.0, 9.0 }, matrix[2], 0.0);
+ Assert.assertArrayEquals(new double[] {1.0, 2.0, 3.0}, matrix[0], 0.0);
+ Assert.assertArrayEquals(new double[] {4.0, 5.0, 6.0}, matrix[1], 0.0);
+ Assert.assertArrayEquals(new double[] {7.0, 8.0, 9.0}, matrix[2], 0.0);
}
@Test
public void testGetTuple1DML() {
- System.out.println("MLContextTest - Get Tuple1<Matrix> DML");
+ LOG.debug("MLContextTest - Get Tuple1<Matrix> DML");
JavaRDD<String> javaRddString = sc
- .parallelize(Stream.of("1,2,3", "4,5,6", "7,8,9").collect(Collectors.toList()));
+ .parallelize(Stream.of("1,2,3", "4,5,6", "7,8,9").collect(Collectors.toList()));
JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
@@ -1847,7 +1855,7 @@
Dataset<Row> df = spark.createDataFrame(javaRddRow, schema);
Script script = dml("N=M*2").in("M", df).out("N");
- Tuple1<Matrix> tuple = ml.execute(script).getTuple("N");
+ Tuple1<Matrix> tuple = executeAndCaptureStdOut(script).getLeft().getTuple("N");
double[][] n = tuple._1().to2DDoubleArray();
Assert.assertEquals(2.0, n[0][0], 0);
Assert.assertEquals(4.0, n[0][1], 0);
@@ -1862,25 +1870,25 @@
@Test
public void testGetTuple2DML() {
- System.out.println("MLContextTest - Get Tuple2<Matrix,Double> DML");
+ LOG.debug("MLContextTest - Get Tuple2<Matrix,Double> DML");
- double[][] m = new double[][] { { 1, 2 }, { 3, 4 } };
+ double[][] m = new double[][] {{1, 2}, {3, 4}};
Script script = dml("N=M*2;s=sum(N)").in("M", m).out("N", "s");
- Tuple2<Matrix, Double> tuple = ml.execute(script).getTuple("N", "s");
+ Tuple2<Matrix, Double> tuple = executeAndCaptureStdOut(script).getLeft().getTuple("N", "s");
double[][] n = tuple._1().to2DDoubleArray();
double s = tuple._2();
- Assert.assertArrayEquals(new double[] { 2, 4 }, n[0], 0.0);
- Assert.assertArrayEquals(new double[] { 6, 8 }, n[1], 0.0);
+ Assert.assertArrayEquals(new double[] {2, 4}, n[0], 0.0);
+ Assert.assertArrayEquals(new double[] {6, 8}, n[1], 0.0);
Assert.assertEquals(20.0, s, 0.0);
}
@Test
public void testGetTuple3DML() {
- System.out.println("MLContextTest - Get Tuple3<Long,Double,Boolean> DML");
+ LOG.debug("MLContextTest - Get Tuple3<Long,Double,Boolean> DML");
Script script = dml("a=1+2;b=a+0.5;c=TRUE;").out("a", "b", "c");
- Tuple3<Long, Double, Boolean> tuple = ml.execute(script).getTuple("a", "b", "c");
+ Tuple3<Long, Double, Boolean> tuple = executeAndCaptureStdOut(script).getLeft().getTuple("a", "b", "c");
long a = tuple._1();
double b = tuple._2();
boolean c = tuple._3();
@@ -1891,10 +1899,11 @@
@Test
public void testGetTuple4DML() {
- System.out.println("MLContextTest - Get Tuple4<Long,Double,Boolean,String> DML");
+ LOG.debug("MLContextTest - Get Tuple4<Long,Double,Boolean,String> DML");
Script script = dml("a=1+2;b=a+0.5;c=TRUE;d=\"yes it's \"+c").out("a", "b", "c", "d");
- Tuple4<Long, Double, Boolean, String> tuple = ml.execute(script).getTuple("a", "b", "c", "d");
+ Tuple4<Long, Double, Boolean, String> tuple = executeAndCaptureStdOut(script).getLeft()
+ .getTuple("a", "b", "c", "d");
long a = tuple._1();
double b = tuple._2();
boolean c = tuple._3();
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
index 602628f..ce1abf3 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
@@ -16,23 +16,33 @@
* specific language governing permissions and limitations
* under the License.
*/
-
+
package org.apache.sysds.test.functions.mlcontext;
+import static org.junit.Assert.fail;
+
+import java.io.ByteArrayOutputStream;
+import java.io.PrintStream;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
+import org.apache.sysds.api.mlcontext.MLContext;
+import org.apache.sysds.api.mlcontext.MLContextUtil;
+import org.apache.sysds.api.mlcontext.MLResults;
+import org.apache.sysds.api.mlcontext.Script;
+import org.apache.sysds.api.mlcontext.ScriptExecutor;
+import org.apache.sysds.test.AutomatedTestBase;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.BeforeClass;
-import org.apache.sysds.api.mlcontext.MLContext;
-import org.apache.sysds.api.mlcontext.MLContextUtil;
-import org.apache.sysds.test.AutomatedTestBase;
/**
* Abstract class that can be used for MLContext tests.
* <p>
- * Note that if using the setUp() method of MLContextTestBase, the test directory
- * and test name can be specified if needed in the subclass.
+ * Note that if using the setUp() method of MLContextTestBase, the test directory and test name can be specified if
+ * needed in the subclass.
* <p>
*
* Example:
@@ -87,4 +97,94 @@
ml.close();
ml = null;
}
+
+ public static Pair<MLResults, String> executeAndCaptureStdOut(Script script){
+ ByteArrayOutputStream buff = new ByteArrayOutputStream();
+ PrintStream ps = new PrintStream(buff);
+ PrintStream old = System.out;
+ System.setOut(ps);
+ MLResults res = safeExecute(buff, script, null);
+ System.out.flush();
+ System.setOut(old);
+
+ return new ImmutablePair<>(res, buff.toString());
+ }
+
+ public static Pair<MLResults, String> executeAndCaptureStdOut(Script script, Class<?> expectedException){
+ if(expectedException == null){
+ return executeAndCaptureStdOut(script);
+ }
+
+ ByteArrayOutputStream buff = new ByteArrayOutputStream();
+ PrintStream ps = new PrintStream(buff);
+ PrintStream old = System.out;
+ System.setOut(ps);
+		MLResults res = unsafeExecute(script, null, expectedException);
+ System.out.flush();
+ System.setOut(old);
+
+ return new ImmutablePair<>(res, buff.toString());
+ }
+
+
+ public static Pair<MLResults, String> executeAndCaptureStdOut(MLContext ml, Script script){
+ ByteArrayOutputStream buff = new ByteArrayOutputStream();
+ PrintStream ps = new PrintStream(buff);
+ PrintStream old = System.out;
+ System.setOut(ps);
+ MLResults res = safeExecute(buff, script, null);
+ System.out.flush();
+ System.setOut(old);
+
+ return new ImmutablePair<>(res, buff.toString());
+ }
+
+ public static Pair<MLResults, String> executeAndCaptureStdOut(MLContext ml, Script script, ScriptExecutor sce){
+ ByteArrayOutputStream buff = new ByteArrayOutputStream();
+ PrintStream ps = new PrintStream(buff);
+ PrintStream old = System.out;
+ System.setOut(ps);
+		MLResults res = safeExecute(buff, script, sce);
+ System.out.flush();
+ System.setOut(old);
+
+ return new ImmutablePair<>(res, buff.toString());
+ }
+
+ private static MLResults safeExecute(ByteArrayOutputStream buff, Script script, ScriptExecutor sce){
+ try {
+
+ MLResults res = sce == null ? ml.execute(script): ml.execute(script,sce);
+ return res;
+ }
+ catch(Exception e) {
+ StringBuilder errorMessage = new StringBuilder();
+ errorMessage.append("\nfailed to run script: ");
+ errorMessage.append("\nStandard Out:");
+ errorMessage.append("\n" + buff);
+ errorMessage.append("\nStackTrace:");
+ errorMessage.append(AutomatedTestBase.getStackTraceString(e, 0));
+ fail(errorMessage.toString());
+ }
+ return null;
+ }
+
+ private static MLResults unsafeExecute(Script script, ScriptExecutor sce, Class<?> expectedException){
+ try {
+
+ MLResults res = sce == null ? ml.execute(script): ml.execute(script, sce);
+ return res;
+ }
+ catch(Exception e) {
+ if(!(e.getClass().equals(expectedException))){
+
+ StringBuilder errorMessage = new StringBuilder();
+ errorMessage.append("\nfailed to run script: ");
+ errorMessage.append("\nStackTrace:");
+ errorMessage.append(AutomatedTestBase.getStackTraceString(e, 0));
+ fail(errorMessage.toString());
+ }
+ }
+ return null;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/nary/NaryListTest.java b/src/test/java/org/apache/sysds/test/functions/nary/NaryListTest.java
new file mode 100644
index 0000000..34cfe94
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/nary/NaryListTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.nary;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class NaryListTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "NaryList";
+ private final static String TEST_DIR = "functions/nary/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + NaryListTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+ }
+
+ @Test
+ public void test(){
+ TestConfiguration config = getAndLoadTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config);
+
+ String RI_HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-stats" };
+
+ String out = runTest(true, false, null, -1).toString();
+ assertTrue( "Output: " + out, out.contains("[hi, Im, a, list]"));
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/nary/NaryListTestAdvanced.java b/src/test/java/org/apache/sysds/test/functions/nary/NaryListTestAdvanced.java
new file mode 100644
index 0000000..f4feb49
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/nary/NaryListTestAdvanced.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.nary;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class NaryListTestAdvanced extends AutomatedTestBase {
+ private final static String TEST_NAME = "NaryListAdvanced";
+ private final static String TEST_DIR = "functions/nary/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + NaryListTestAdvanced.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+ }
+
+ @Test
+ public void test() {
+ TestConfiguration config = getAndLoadTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config);
+
+ String RI_HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {""};
+
+ String out = runTest(true, false, null, -1).toString();
+ assertTrue("Output: " + out,
+ out.contains(String.join("\n",
+ "[1, Im, ",
+ "Matrix:",
+ "1.000 1.000",
+ "1.000 1.000",
+ ", ",
+ "# FRAME: nrow = 2, ncol = 2",
+ "# C1 C2",
+ "# FP64 FP64",
+ "1.000 1.000",
+ "1.000 1.000",
+ "]")));
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties
index f0e65d2..944f3fd 100644
--- a/src/test/resources/log4j.properties
+++ b/src/test/resources/log4j.properties
@@ -24,7 +24,7 @@
log4j.logger.org.apache.sysds.api.DMLScript=OFF
log4j.logger.org.apache.sysds.test=INFO
log4j.logger.org.apache.sysds.test.AutomatedTestBase=ERROR
-log4j.logger.org.apache.sysds=ERROR
+log4j.logger.org.apache.sysds=WARN
log4j.logger.org.apache.spark=OFF
log4j.logger.org.apache.hadoop=OFF
diff --git a/src/test/scripts/functions/io/csv/ReadCSVTest_4.dml b/src/test/scripts/functions/io/csv/ReadCSVTest_4.dml
new file mode 100644
index 0000000..4729b2a
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/ReadCSVTest_4.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# DML script that tests read csv
+
+A = read($1, format="csv", header=TRUE, naStrings= ["NA"] );
+x = colSums(A);
+print(toString(x));
+
diff --git a/src/test/scripts/functions/io/csv/ReadCSVTest_5.dml b/src/test/scripts/functions/io/csv/ReadCSVTest_5.dml
new file mode 100644
index 0000000..2ffac90
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/ReadCSVTest_5.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1, value_type="double", format="csv", header=TRUE, naStrings= ["NA", "null", "fake", "nothing", "nan"] );
+x = colSums(A);
+print(toString(x));
+
diff --git a/src/test/scripts/functions/io/csv/ReadCSVTest_6.dml b/src/test/scripts/functions/io/csv/ReadCSVTest_6.dml
new file mode 100644
index 0000000..6aab6a5
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/ReadCSVTest_6.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+x = colSums(A);
+print(toString(x));
+
diff --git a/src/test/scripts/functions/io/csv/ReadFrameTest_1.dml b/src/test/scripts/functions/io/csv/ReadFrameTest_1.dml
new file mode 100644
index 0000000..0025a3f
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/ReadFrameTest_1.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+x = read($1);
+print(toString(x));
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/ReadFrameTest_2.dml b/src/test/scripts/functions/io/csv/ReadFrameTest_2.dml
new file mode 100644
index 0000000..e34b2a9
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/ReadFrameTest_2.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+x = read($1, data_type="frame", format="csv", header=TRUE );
+print(toString(x));
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/ReadFrameTest_3.dml b/src/test/scripts/functions/io/csv/ReadFrameTest_3.dml
new file mode 100644
index 0000000..0025a3f
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/ReadFrameTest_3.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+x = read($1);
+print(toString(x));
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/frame_1.csv b/src/test/scripts/functions/io/csv/in/frame_1.csv
new file mode 100644
index 0000000..1d55f2b
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/frame_1.csv
@@ -0,0 +1,3 @@
+c1,c2,c4,c5
+hi,1,five,columbia
+goodbye,2,four,new york
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/frame_1.csv.mtd b/src/test/scripts/functions/io/csv/in/frame_1.csv.mtd
new file mode 100644
index 0000000..12c28de
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/frame_1.csv.mtd
@@ -0,0 +1,8 @@
+{
+ "data_type": "frame",
+ "format": "csv",
+ "header": true,
+ "description": {
+ "author": "SystemDS"
+ }
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/frame_2.csv b/src/test/scripts/functions/io/csv/in/frame_2.csv
new file mode 100644
index 0000000..1d55f2b
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/frame_2.csv
@@ -0,0 +1,3 @@
+c1,c2,c4,c5
+hi,1,five,columbia
+goodbye,2,four,new york
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/frame_3.csv b/src/test/scripts/functions/io/csv/in/frame_3.csv
new file mode 100644
index 0000000..e1acd95
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/frame_3.csv
@@ -0,0 +1,3 @@
+c1,c2,c4,c5
+hi,1,five,
+goodbye,2,four,new york
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/frame_3.csv.mtd b/src/test/scripts/functions/io/csv/in/frame_3.csv.mtd
new file mode 100644
index 0000000..fcdad84
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/frame_3.csv.mtd
@@ -0,0 +1,13 @@
+{
+ "data_type": "frame",
+ "format": "csv",
+ "header": true,
+ "description": {
+ "author": "SystemDS"
+ },
+ "naStrings": [
+ "hi",
+ "goodbye",
+ ""
+ ]
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/nan_integers_4.csv b/src/test/scripts/functions/io/csv/in/nan_integers_4.csv
new file mode 100644
index 0000000..dab37bd
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/nan_integers_4.csv
@@ -0,0 +1,9 @@
+col1,col2,col3,col4
+NA,1,NA,NA
+NA,1,1,NA
+NA,1,NA,1
+NA,1,NA,NA
+1,1,NA,NA
+NA,1,NA,NA
+NA,1,NA,NA
+NA,1,NA,NA
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/nan_integers_5.csv b/src/test/scripts/functions/io/csv/in/nan_integers_5.csv
new file mode 100644
index 0000000..b9c411b
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/nan_integers_5.csv
@@ -0,0 +1,9 @@
+col1,col2,col3,col4
+NA,1,fake,NA
+nan,1,1,NA
+NA,1,NA,1
+NA,1,NA,nan
+1,1,null,NA
+NA,1,NA,fake
+NA,1,NA,NA
+NA, 1, fake, nothing
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/nan_integers_6.csv b/src/test/scripts/functions/io/csv/in/nan_integers_6.csv
new file mode 100644
index 0000000..b9c411b
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/nan_integers_6.csv
@@ -0,0 +1,9 @@
+col1,col2,col3,col4
+NA,1,fake,NA
+nan,1,1,NA
+NA,1,NA,1
+NA,1,NA,nan
+1,1,null,NA
+NA,1,NA,fake
+NA,1,NA,NA
+NA, 1, fake, nothing
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/nan_integers_6.csv.mtd b/src/test/scripts/functions/io/csv/in/nan_integers_6.csv.mtd
new file mode 100644
index 0000000..e463b58
--- /dev/null
+++ b/src/test/scripts/functions/io/csv/in/nan_integers_6.csv.mtd
@@ -0,0 +1,15 @@
+{
+ "data_type": "matrix",
+ "format": "csv",
+ "header": true,
+ "description": {
+ "author": "SystemDS"
+ },
+ "naStrings": [
+ "NA",
+ "null",
+ "fake",
+ "nothing",
+ "nan"
+ ]
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/transfusion_1.csv.mtd b/src/test/scripts/functions/io/csv/in/transfusion_1.csv.mtd
index 6df936e..40387b9 100644
--- a/src/test/scripts/functions/io/csv/in/transfusion_1.csv.mtd
+++ b/src/test/scripts/functions/io/csv/in/transfusion_1.csv.mtd
@@ -2,5 +2,5 @@
"data_type": "matrix"
,"format": "csv"
,"header": true
- ,"description": { "author": "SystemML" }
+ ,"description": { "author": "SystemDS" }
}
\ No newline at end of file
diff --git a/src/test/scripts/functions/io/csv/in/transfusion_3.csv.mtd b/src/test/scripts/functions/io/csv/in/transfusion_3.csv.mtd
index d7712dd..bc2b47f 100644
--- a/src/test/scripts/functions/io/csv/in/transfusion_3.csv.mtd
+++ b/src/test/scripts/functions/io/csv/in/transfusion_3.csv.mtd
@@ -5,5 +5,5 @@
,"sep": ","
,"rows": 748
,"cols": 5
- ,"description": { "author": "SystemML" }
+ ,"description": { "author": "SystemDS" }
}
\ No newline at end of file
diff --git a/src/test/scripts/functions/nary/NaryList.dml b/src/test/scripts/functions/nary/NaryList.dml
new file mode 100644
index 0000000..9c5b784
--- /dev/null
+++ b/src/test/scripts/functions/nary/NaryList.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+x = list("hi", "Im", "a", "list")
+print(toString(x))
\ No newline at end of file
diff --git a/src/test/scripts/functions/nary/NaryListAdvanced.dml b/src/test/scripts/functions/nary/NaryListAdvanced.dml
new file mode 100644
index 0000000..da89856
--- /dev/null
+++ b/src/test/scripts/functions/nary/NaryListAdvanced.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+A = matrix(1, rows=2,cols=2)
+Af = as.frame(A)
+x = list(1, "Im", A, Af)
+print(toString(x))
\ No newline at end of file