[SYSTEMDS-2764] Frame constructor and data-gen operations
DIA project WS2020/21.
Closes #1132.
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index c76cd1c..f03feea 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -448,7 +448,7 @@
}
public enum OpOpDG {
- RAND, SEQ, SINIT, SAMPLE, TIME
+ RAND, SEQ, FRAMEINIT, SINIT, SAMPLE, TIME
}
public enum OpOpData {
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 6e960e7..c57fbc8 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -1357,7 +1357,7 @@
DataGenOp d = (DataGenOp) hop;
HashMap<String,Integer> params = d.getParamIndexMap();
if ( d.getOp() == OpOpDG.RAND || d.getOp()==OpOpDG.SINIT
- || d.getOp() == OpOpDG.SAMPLE )
+ || d.getOp() == OpOpDG.SAMPLE || d.getOp() == OpOpDG.FRAMEINIT )
{
boolean initUnknown = !d.dimsKnown();
// TODO refresh tensor size information
diff --git a/src/main/java/org/apache/sysds/lops/DataGen.java b/src/main/java/org/apache/sysds/lops/DataGen.java
index ddc1a8a..7487a59 100644
--- a/src/main/java/org/apache/sysds/lops/DataGen.java
+++ b/src/main/java/org/apache/sysds/lops/DataGen.java
@@ -42,7 +42,8 @@
public static final String SINIT_OPCODE = "sinit"; //string initialize
public static final String SAMPLE_OPCODE = "sample"; //sample.int
public static final String TIME_OPCODE = "time"; //time
-
+ public static final String FRAME_OPCODE = "frame"; //frame constructor
+
private int _numThreads = 1;
/** base dir for rand input */
@@ -111,6 +112,8 @@
return getSampleInstructionCPSpark(output);
case TIME:
return getTimeInstructionCP(output);
+ case FRAMEINIT:
+ return getFrameInstructionCPSpark(output);
default:
throw new LopsException("Unknown data generation method: " + _op);
}
@@ -206,6 +209,76 @@
return sb.toString();
}
+ /**
+  * Builds the CP/Spark instruction string of the frame constructor in the order:
+  * exec type, opcode, data, (dims | rows, cols), schema, [base dir if Spark], output.
+  * Throws a LopsException for a non-FRAMEINIT lop or an unexpected number of inputs.
+  */
+ private String getFrameInstructionCPSpark(String output)
+ {
+     //sanity checks
+     if ( _op != OpOpDG.FRAMEINIT )
+         throw new LopsException("Invalid instruction generation for data generation method " + _op);
+     if( getInputs().size() != DataExpression.FRAME_VALID_PARAM_NAMES.size() ) { // one input per legal frame parameter
+         throw new LopsException(printErrorLocation() + "Invalid number of operands ("
+             + getInputs().size() + ") for a frame operation");
+     }
+
+     StringBuilder sb = new StringBuilder();
+     sb.append( getExecType() );
+     sb.append( Lop.OPERAND_DELIMITOR );
+     sb.append(FRAME_OPCODE);
+     sb.append(OPERAND_DELIMITOR);
+     // data literal(s): each element of an n-ary input is followed by the NA-string separator
+     Lop iLop = _inputParams.get(DataExpression.RAND_DATA);
+     if ( iLop != null ) {
+         if(iLop instanceof Nary) {
+             for(Lop lop : iLop.getInputs()) {
+                 sb.append(((Data)lop).getStringValue());
+                 sb.append(DataExpression.DELIM_NA_STRING_SEP);
+             }
+         }
+         else if(iLop instanceof Data) {
+             sb.append(((Data)iLop).getStringValue());
+         }
+     }
+     sb.append(OPERAND_DELIMITOR);
+     // dimensions: a single dims operand if given, otherwise rows followed by cols
+     iLop = _inputParams.get(DataExpression.RAND_DIMS);
+     if (iLop != null) {
+         sb.append(iLop.prepScalarInputOperand(getExecType()));
+         sb.append(OPERAND_DELIMITOR);
+     }
+     else {
+         iLop = _inputParams.get(DataExpression.RAND_ROWS);
+         sb.append(iLop.prepScalarInputOperand(getExecType()));
+         sb.append(OPERAND_DELIMITOR);
+         iLop = _inputParams.get(DataExpression.RAND_COLS);
+         sb.append(iLop.prepScalarInputOperand(getExecType()));
+         sb.append(OPERAND_DELIMITOR);
+     }
+     iLop = _inputParams.get(DataExpression.SCHEMAPARAM);
+     if ( iLop != null ) {
+         if(iLop instanceof Nary) {
+             for(Lop lop : iLop.getInputs()) {
+                 sb.append(((Data)lop).getStringValue());
+                 sb.append(DataExpression.DELIM_NA_STRING_SEP);
+             }
+         }
+         else if(iLop instanceof Data) {
+             sb.append(((Data)iLop).getStringValue());
+         }
+     }
+     sb.append(OPERAND_DELIMITOR);
+
+     if( getExecType() == ExecType.SPARK ) {
+         sb.append(baseDir);
+         sb.append(OPERAND_DELIMITOR);
+     }
+     sb.append( prepOutputOperand(output));
+     return sb.toString();
+ }
+
private String getSInitInstructionCPSpark(String output)
{
if ( _op != OpOpDG.SINIT )
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index aab0d22..8e33a15 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2093,6 +2093,12 @@
currBuiltinOp = new DataGenOp(method, target, paramHops);
break;
+ case FRAME:
+ // frame constructor: compiled into a FRAMEINIT data-gen op over the given parameter hops
+ method = OpOpDG.FRAMEINIT;
+ currBuiltinOp = new DataGenOp(method, target, paramHops);
+ break;
+
case TENSOR:
case MATRIX:
ArrayList<Hop> tmpMatrix = new ArrayList<>();
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java
index f9fe5a4..dcea873 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -28,6 +28,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
@@ -125,7 +126,10 @@
public static final Set<String> RESHAPE_VALID_PARAM_NAMES = new HashSet<>(
Arrays.asList(RAND_BY_ROW, RAND_DIMNAMES, RAND_DATA, RAND_ROWS, RAND_COLS, RAND_DIMS));
-
+
+ public static final Set<String> FRAME_VALID_PARAM_NAMES = new HashSet<>(
+ Arrays.asList(SCHEMAPARAM, RAND_DATA, RAND_ROWS, RAND_COLS));
+
public static final Set<String> SQL_VALID_PARAM_NAMES = new HashSet<>(
Arrays.asList(SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY));
@@ -156,7 +160,8 @@
public static final double DEFAULT_DELIM_FILL_VALUE = 0.0;
public static final boolean DEFAULT_DELIM_SPARSE = false;
public static final String DEFAULT_NA_STRINGS = "";
-
+ public static final String DEFAULT_SCHEMAPARAM = "NULL";
+
private DataOp _opcode;
private HashMap<String, Expression> _varParams;
private boolean _strInit = false; //string initialize
@@ -186,309 +191,361 @@
+ passedParamExprs + " " + parseInfo + " " + errorListener);
}
// check if the function name is built-in function
- // (assign built-in function op if function is built-in)
- Expression.DataOp dop;
+ // (assign built-in function op if function is built-in)
DataExpression dataExpr = null;
- if (functionName.equals("read") || functionName.equals("readMM") || functionName.equals("read.csv")) {
- dop = Expression.DataOp.READ;
- dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo);
+ if (functionName.equals("read") || functionName.equals("readMM") || functionName.equals("read.csv"))
+ dataExpr = processReadDataExpression(functionName, passedParamExprs, errorListener, parseInfo);
+ else if (functionName.equalsIgnoreCase("rand"))
+ dataExpr = processRandDataExpression(functionName, passedParamExprs, errorListener, parseInfo);
+ else if (functionName.equals("matrix"))
+ dataExpr = processMatrixExpression(functionName, passedParamExprs, errorListener, parseInfo);
+ else if (functionName.equals("frame"))
+ dataExpr = processFrameExpression(functionName, passedParamExprs, errorListener, parseInfo);
+ else if (functionName.equals("tensor"))
+ dataExpr = processTensorExpression(functionName, passedParamExprs, errorListener, parseInfo);
+ else if (functionName.equals("sql"))
+ dataExpr = processSQLExpression(functionName, passedParamExprs, errorListener, parseInfo);
+ else if (functionName.equals("federated"))
+ dataExpr = processFederatedExpression(functionName, passedParamExprs, errorListener, parseInfo);
+
+ if (dataExpr != null)
+ dataExpr.setParseInfo(parseInfo);
+ return dataExpr;
+ }
+
+ private static DataExpression processReadDataExpression(String functionName,
+ List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo)
+ {
+ DataExpression dataExpr = new DataExpression(DataOp.READ, new HashMap<>(), parseInfo);
+ if (functionName.equals("readMM"))
+ dataExpr.addVarParam(DataExpression.FORMAT_TYPE,
+ new StringIdentifier(FileFormat.MM.toString(), parseInfo));
- if (functionName.equals("readMM"))
- dataExpr.addVarParam(DataExpression.FORMAT_TYPE,
- new StringIdentifier(FileFormat.MM.toString(), parseInfo));
+ if (functionName.equals("read.csv"))
+ dataExpr.addVarParam(DataExpression.FORMAT_TYPE,
+ new StringIdentifier(FileFormat.CSV.toString(), parseInfo));
- if (functionName.equals("read.csv"))
- dataExpr.addVarParam(DataExpression.FORMAT_TYPE,
- new StringIdentifier(FileFormat.CSV.toString(), parseInfo));
+ if (functionName.equals("read.libsvm"))
+ dataExpr.addVarParam(DataExpression.FORMAT_TYPE,
+ new StringIdentifier(FileFormat.LIBSVM.toString(), parseInfo));
- if (functionName.equals("read.libsvm"))
- dataExpr.addVarParam(DataExpression.FORMAT_TYPE,
- new StringIdentifier(FileFormat.LIBSVM.toString(), parseInfo));
-
- // validate the filename is the first parameter
- if (passedParamExprs.size() < 1){
- errorListener.validationError(parseInfo, "read method must have at least filename parameter");
- return null;
- }
-
- ParameterExpression pexpr = (passedParamExprs.size() == 0) ? null : passedParamExprs.get(0);
-
- if ( (pexpr != null) && (!(pexpr.getName() == null) || (pexpr.getName() != null && pexpr.getName().equalsIgnoreCase(DataExpression.IO_FILENAME)))){
- errorListener.validationError(parseInfo, "first parameter to read statement must be filename");
- return null;
- } else if( pexpr != null ){
- dataExpr.addVarParam(DataExpression.IO_FILENAME, pexpr.getExpr());
- }
-
- // validate all parameters are added only once and valid name
- for (int i = 1; i < passedParamExprs.size(); i++){
- String currName = passedParamExprs.get(i).getName();
- Expression currExpr = passedParamExprs.get(i).getExpr();
-
- if (dataExpr.getVarParam(currName) != null){
- errorListener.validationError(parseInfo, "attempted to add IOStatement parameter " + currName + " more than once");
- return null;
- }
- // verify parameter names for read function
- boolean isValidName = READ_VALID_PARAM_NAMES.contains(currName);
-
- if (!isValidName){
- errorListener.validationError(parseInfo, "attempted to add invalid read statement parameter " + currName);
- return null;
- }
- dataExpr.addVarParam(currName, currExpr);
- }
- }
- else if (functionName.equalsIgnoreCase("rand")){
-
- dop = Expression.DataOp.RAND;
- dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo);
-
- for (ParameterExpression currExpr : passedParamExprs){
- String pname = currExpr.getName();
- Expression pexpr = currExpr.getExpr();
- if (pname == null){
- errorListener.validationError(parseInfo, "for rand statement, all arguments must be named parameters");
- return null;
- }
- dataExpr.addRandExprParam(pname, pexpr);
- }
- dataExpr.setRandDefault();
+ // validate the filename is the first parameter
+ if (passedParamExprs.size() < 1){
+ errorListener.validationError(parseInfo, "read method must have at least filename parameter");
+ return null;
}
- else if (functionName.equals("matrix")){
- dop = Expression.DataOp.MATRIX;
- dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo);
+ ParameterExpression pexpr = (passedParamExprs.size() == 0) ? null : passedParamExprs.get(0);
- int namedParamCount = 0, unnamedParamCount = 0;
- for (ParameterExpression currExpr : passedParamExprs) {
- if (currExpr.getName() == null)
- unnamedParamCount++;
- else
- namedParamCount++;
+ if ( (pexpr != null) && (!(pexpr.getName() == null) || (pexpr.getName() != null && pexpr.getName().equalsIgnoreCase(DataExpression.IO_FILENAME)))){
+ errorListener.validationError(parseInfo, "first parameter to read statement must be filename");
+ return null;
+ } else if( pexpr != null ){
+ dataExpr.addVarParam(DataExpression.IO_FILENAME, pexpr.getExpr());
+ }
+
+ // validate all parameters are added only once and valid name
+ for (int i = 1; i < passedParamExprs.size(); i++){
+ String currName = passedParamExprs.get(i).getName();
+ Expression currExpr = passedParamExprs.get(i).getExpr();
+
+ if (dataExpr.getVarParam(currName) != null){
+ errorListener.validationError(parseInfo, "attempted to add IOStatement parameter " + currName + " more than once");
+ return null;
}
+ // verify parameter names for read function
+ boolean isValidName = READ_VALID_PARAM_NAMES.contains(currName);
- // check whether named or unnamed parameters are used
- if (passedParamExprs.size() < 3){
+ if (!isValidName){
+ errorListener.validationError(parseInfo, "attempted to add invalid read statement parameter " + currName);
+ return null;
+ }
+ dataExpr.addVarParam(currName, currExpr);
+ }
+
+ return dataExpr;
+ }
+
+ private static DataExpression processRandDataExpression(String functionName,
+         List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo)
+ {
+     // rand takes named arguments only; the first unnamed one aborts with a validation error
+     DataExpression randExpr = new DataExpression(DataOp.RAND, new HashMap<>(), parseInfo);
+     for (ParameterExpression arg : passedParamExprs){
+         String argName = arg.getName();
+         Expression argValue = arg.getExpr();
+         if (argName == null){
+             errorListener.validationError(parseInfo, "for rand statement, all arguments must be named parameters");
+             return null;
+         }
+         randExpr.addRandExprParam(argName, argValue);
+     }
+     randExpr.setRandDefault();
+     return randExpr;
+ }
+
+ private static DataExpression processMatrixExpression(String functionName,
+ List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo)
+ {
+ DataExpression dataExpr = new DataExpression(DataOp.MATRIX, new HashMap<>(), parseInfo);
+ int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count();
+ int unnamedParamCount = passedParamExprs.size() - namedParamCount;
+
+ // check whether named or unnamed parameters are used
+ if (passedParamExprs.size() < 3){
+ errorListener.validationError(parseInfo, "for matrix statement, must specify at least 3 arguments: data, rows, cols");
+ return null;
+ }
+
+ if (unnamedParamCount > 1){
+ if (namedParamCount > 0) {
+ errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters");
+ return null;
+ }
+ if (unnamedParamCount < 3) {
errorListener.validationError(parseInfo, "for matrix statement, must specify at least 3 arguments: data, rows, cols");
return null;
}
-
- if (unnamedParamCount > 1){
-
- if (namedParamCount > 0) {
- errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters");
- return null;
- }
-
- if (unnamedParamCount < 3) {
- errorListener.validationError(parseInfo, "for matrix statement, must specify at least 3 arguments: data, rows, cols");
- return null;
- }
-
- // assume: data, rows, cols, [byRow], [dimNames]
- dataExpr.addMatrixExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr());
- dataExpr.addMatrixExprParam(DataExpression.RAND_ROWS,passedParamExprs.get(1).getExpr());
- dataExpr.addMatrixExprParam(DataExpression.RAND_COLS,passedParamExprs.get(2).getExpr());
-
- if (unnamedParamCount >= 4)
- dataExpr.addMatrixExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(3).getExpr());
-
- if (unnamedParamCount == 5)
- dataExpr.addMatrixExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(4).getExpr());
-
- if (unnamedParamCount > 5) {
- errorListener.validationError(parseInfo, "for matrix statement, at most 5 arguments supported: data, rows, cols, byrow, dimname");
- return null;
- }
-
+ // assume: data, rows, cols, [byRow], [dimNames]
+ dataExpr.addMatrixExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr());
+ dataExpr.addMatrixExprParam(DataExpression.RAND_ROWS,passedParamExprs.get(1).getExpr());
+ dataExpr.addMatrixExprParam(DataExpression.RAND_COLS,passedParamExprs.get(2).getExpr());
+
+ if (unnamedParamCount >= 4)
+ dataExpr.addMatrixExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(3).getExpr());
+ if (unnamedParamCount == 5)
+ dataExpr.addMatrixExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(4).getExpr());
+ if (unnamedParamCount > 5) {
+ errorListener.validationError(parseInfo, "for matrix statement, at most 5 arguments supported: data, rows, cols, byrow, dimname");
+ return null;
+ }
+ }
+ else {
+ // handle first parameter, which is data and may be unnamed
+ ParameterExpression firstParam = passedParamExprs.get(0);
+ if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){
+ errorListener.validationError(parseInfo, "matrix method must have data parameter as first parameter or unnamed parameter");
+ return null;
} else {
- // handle first parameter, which is data and may be unnamed
- ParameterExpression firstParam = passedParamExprs.get(0);
- if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){
- errorListener.validationError(parseInfo, "matrix method must have data parameter as first parameter or unnamed parameter");
+ dataExpr.addMatrixExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr());
+ }
+
+ for (int i=1; i<passedParamExprs.size(); i++){
+ if (passedParamExprs.get(i).getName() == null){
+ errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters, only data parameter can be unnammed");
return null;
} else {
- dataExpr.addMatrixExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr());
- }
-
- for (int i=1; i<passedParamExprs.size(); i++){
- if (passedParamExprs.get(i).getName() == null){
- errorListener.validationError(parseInfo, "for matrix statement, cannot mix named and unnamed parameters, only data parameter can be unnammed");
- return null;
- } else {
- dataExpr.addMatrixExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr());
- }
+ dataExpr.addMatrixExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr());
}
}
- dataExpr.setMatrixDefault();
}
- else if (functionName.equals("tensor")){
- dop = Expression.DataOp.TENSOR;
- dataExpr = new DataExpression(dop, new HashMap<String, Expression>(), parseInfo);
-
- int namedParamCount = 0, unnamedParamCount = 0;
- for (ParameterExpression currExpr : passedParamExprs) {
- if (currExpr.getName() == null)
- unnamedParamCount++;
- else
- namedParamCount++;
- }
-
- // check whether named or unnamed parameters are used
- if (passedParamExprs.size() < 2){
- errorListener.validationError(parseInfo, "for tensor statement, must specify at least 2 arguments: data, dims[]");
- return null;
- }
-
- if (unnamedParamCount > 1){
- if (namedParamCount > 0) {
- errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters");
- return null;
- }
-
- // assume: data, dims[], [byRow], [dimNames]
- dataExpr.addTensorExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr());
- dataExpr.addTensorExprParam(DataExpression.RAND_DIMS,passedParamExprs.get(1).getExpr());
-
- if (unnamedParamCount >= 3)
- // TODO use byRow parameter
- dataExpr.addTensorExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(2).getExpr());
-
- if (unnamedParamCount == 4)
- dataExpr.addTensorExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(3).getExpr());
-
- if (unnamedParamCount > 4) {
- errorListener.validationError(parseInfo, "for tensor statement, at most 4 arguments supported: data, dims, byrow, dimname");
- return null;
- }
-
- }
- else {
- // handle first parameter, which is data and may be unnamed
- ParameterExpression firstParam = passedParamExprs.get(0);
- if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){
- errorListener.validationError(parseInfo, "tensor method must have data parameter as first parameter or unnamed parameter");
- return null;
- }
- else {
- dataExpr.addTensorExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr());
- }
-
- for (int i=1; i<passedParamExprs.size(); i++){
- if (passedParamExprs.get(i).getName() == null){
- errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters, only data parameter can be unnammed");
- return null;
- }
- else {
- dataExpr.addTensorExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr());
- }
- }
- }
- dataExpr.setTensorDefault();
- }
- else if (functionName.equals("sql")) {
- dop = DataOp.SQL;
- dataExpr = new DataExpression(dop, new HashMap<>(), parseInfo);
-
- int namedParamCount = 0, unnamedParamCount = 0;
- for (ParameterExpression currExpr : passedParamExprs) {
- if (currExpr.getName() == null)
- unnamedParamCount++;
- else
- namedParamCount++;
- }
-
- // check whether named or unnamed parameters are used
- if (passedParamExprs.size() < 2){
- errorListener.validationError(parseInfo, "for sql statement, must specify at least 2 arguments: conn, query");
- return null;
- }
-
- if (unnamedParamCount > 0){
- if (namedParamCount > 0) {
- errorListener.validationError(parseInfo, "for sql statement, cannot mix named and unnamed parameters");
- return null;
- }
-
- if (unnamedParamCount == 2 || unnamedParamCount == 4 ) {
- // assume: conn, query, [password, query]
- dataExpr.addSqlExprParam(DataExpression.SQL_CONN, passedParamExprs.get(0).getExpr());
- dataExpr.addSqlExprParam(DataExpression.SQL_QUERY, passedParamExprs.get(1).getExpr());
- if (unnamedParamCount == 4) {
- dataExpr.addSqlExprParam(DataExpression.SQL_PASS, passedParamExprs.get(2).getExpr());
- dataExpr.addSqlExprParam(DataExpression.SQL_QUERY, passedParamExprs.get(3).getExpr());
- }
- }
- else {
- errorListener.validationError(parseInfo, "for sql statement, "
- + "at most 4 arguments supported: conn, user, password, query");
- return null;
- }
-
- }
- else {
- for (ParameterExpression passedParamExpr : passedParamExprs) {
- dataExpr.addSqlExprParam(passedParamExpr.getName(), passedParamExpr.getExpr());
- }
- }
- dataExpr.setSqlDefault();
- }
- else if (functionName.equals("federated")) {
- dop = DataOp.FEDERATED;
- dataExpr = new DataExpression(dop, new HashMap<>(), parseInfo);
- int namedParamCount = 0, unnamedParamCount = 0;
- for (ParameterExpression currExpr : passedParamExprs) {
- if (currExpr.getName() == null)
- unnamedParamCount++;
- else
- namedParamCount++;
- }
- if(passedParamExprs.size() < 2) {
- errorListener.validationError(parseInfo,
- "for federated statement, must specify at least 2 arguments: addresses, ranges");
- return null;
- }
- if(unnamedParamCount > 0) {
- if(namedParamCount > 0) {
- errorListener.validationError(parseInfo,
- "for federated statement, cannot mix named and unnamed parameters");
- return null;
- }
- if(unnamedParamCount == 2) {
- // first parameter addresses second are the ranges (type defaults to Matrix)
- ParameterExpression param = passedParamExprs.get(0);
- dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr());
- param = passedParamExprs.get(1);
- dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr());
- }
- else if(unnamedParamCount == 3) {
- ParameterExpression param = passedParamExprs.get(0);
- dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr());
- param = passedParamExprs.get(1);
- dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr());
- param = passedParamExprs.get(2);
- dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr());
- }
- else {
- errorListener.validationError(parseInfo,
- "for federated statement, at most 3 arguments are supported: addresses, ranges, type");
- }
- }
- else {
- for (ParameterExpression passedParamExpr : passedParamExprs) {
- dataExpr.addFederatedExprParam(passedParamExpr.getName(), passedParamExpr.getExpr());
- }
- }
- dataExpr.setFederatedDefault();
- }
-
- if (dataExpr != null) {
- dataExpr.setParseInfo(parseInfo);
- }
+ dataExpr.setMatrixDefault();
return dataExpr;
- } // end method getBuiltinFunctionExpression
+ }
+
+ /**
+  * Parses a frame(data, rows, cols, [schema]) call into a FRAME data expression.
+  * Arguments are either all positional, or named with at most the leading data argument unnamed.
+  *
+  * @return the populated expression, or null after a validation error was reported
+  */
+ private static DataExpression processFrameExpression(String functionName,
+         List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo)
+ {
+     DataExpression dataExpr = new DataExpression(DataOp.FRAME, new HashMap<>(), parseInfo);
+     int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count();
+     int unnamedParamCount = passedParamExprs.size() - namedParamCount;
+
+     // data, rows and cols are mandatory; only the schema may be omitted
+     if (passedParamExprs.size() < 3) {
+         errorListener.validationError(parseInfo, "for frame statement, must specify at least 3 arguments: data, rows and cols");
+         return null;
+     }
+
+     if (unnamedParamCount > 1) {
+         if (namedParamCount > 0) {
+             errorListener.validationError(parseInfo, "for frame statement, cannot mix named and unnamed parameters");
+             return null;
+         }
+         // all-positional here, so there are at least 3 arguments; reject anything beyond the 4 supported
+         if (unnamedParamCount > 4) {
+             errorListener.validationError(parseInfo, "for frame statement, at most 4 arguments supported: data, rows, cols, schema");
+             return null;
+         }
+         // positional order: data, rows, cols, [schema]
+         dataExpr.addFrameExprParam(DataExpression.RAND_DATA, passedParamExprs.get(0).getExpr());
+         dataExpr.addFrameExprParam(DataExpression.RAND_ROWS, passedParamExprs.get(1).getExpr());
+         dataExpr.addFrameExprParam(DataExpression.RAND_COLS, passedParamExprs.get(2).getExpr());
+         // the schema is the fourth positional argument (index 3), so it needs 4 arguments
+         if (unnamedParamCount == 4)
+             dataExpr.addFrameExprParam(DataExpression.SCHEMAPARAM, passedParamExprs.get(3).getExpr());
+     }
+     else {
+         // handle first parameter, which is data and may be unnamed
+         ParameterExpression firstParam = passedParamExprs.get(0);
+         if (firstParam.getName() != null && !firstParam.getName().equals(DataExpression.RAND_DATA)){
+             errorListener.validationError(parseInfo, "frame method must have data parameter as first parameter or unnamed parameter");
+             return null;
+         }
+         dataExpr.addFrameExprParam(DataExpression.RAND_DATA, firstParam.getExpr());
+
+         for (int i=1; i<passedParamExprs.size(); i++){
+             if (passedParamExprs.get(i).getName() == null){
+                 errorListener.validationError(parseInfo, "for frame statement, cannot mix named and unnamed parameters, only data parameter can be unnammed");
+                 return null;
+             }
+             dataExpr.addFrameExprParam(passedParamExprs.get(i).getName(), passedParamExprs.get(i).getExpr());
+         }
+     }
+     dataExpr.setFrameDefault();
+     return dataExpr;
+ }
+
+ private static DataExpression processTensorExpression(String functionName,
+         List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo)
+ {
+     // tensor(data, dims, [byRow], [dimNames]): all positional, or named with an optionally unnamed data
+     final DataExpression tensorExpr = new DataExpression(DataOp.TENSOR, new HashMap<>(), parseInfo);
+     final int total = passedParamExprs.size();
+     final int named = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count();
+     final int positional = total - named;
+
+     // at least data and dims are required
+     if (total < 2){
+         errorListener.validationError(parseInfo, "for tensor statement, must specify at least 2 arguments: data, dims[]");
+         return null;
+     }
+
+     if (positional > 1){
+         if (named > 0) {
+             errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters");
+             return null;
+         }
+
+         // assume: data, dims[], [byRow], [dimNames]
+         tensorExpr.addTensorExprParam(DataExpression.RAND_DATA,passedParamExprs.get(0).getExpr());
+         tensorExpr.addTensorExprParam(DataExpression.RAND_DIMS,passedParamExprs.get(1).getExpr());
+         if (positional >= 3)
+             // TODO use byRow parameter
+             tensorExpr.addTensorExprParam(DataExpression.RAND_BY_ROW,passedParamExprs.get(2).getExpr());
+         if (positional == 4)
+             tensorExpr.addTensorExprParam(DataExpression.RAND_DIMNAMES,passedParamExprs.get(3).getExpr());
+         if (positional > 4) {
+             errorListener.validationError(parseInfo, "for tensor statement, at most 4 arguments supported: data, dims, byrow, dimname");
+             return null;
+         }
+         tensorExpr.setTensorDefault();
+         return tensorExpr;
+     }
+
+     // named form: the first argument is the data and may be unnamed
+     ParameterExpression head = passedParamExprs.get(0);
+     if (head.getName() != null && !head.getName().equals(DataExpression.RAND_DATA)){
+         errorListener.validationError(parseInfo, "tensor method must have data parameter as first parameter or unnamed parameter");
+         return null;
+     }
+     tensorExpr.addTensorExprParam(DataExpression.RAND_DATA, head.getExpr());
+
+     // every remaining argument has to carry a name
+     for (ParameterExpression arg : passedParamExprs.subList(1, total)) {
+         if (arg.getName() == null){
+             errorListener.validationError(parseInfo, "for tensor statement, cannot mix named and unnamed parameters, only data parameter can be unnammed");
+             return null;
+         }
+         tensorExpr.addTensorExprParam(arg.getName(), arg.getExpr());
+     }
+     tensorExpr.setTensorDefault();
+     return tensorExpr;
+ }
+
+ /**
+  * Parses sql(conn, [user, password,] query) into an SQL data expression; returns null after a validation error.
+  */
+ private static DataExpression processSQLExpression(String functionName,
+         List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo)
+ {
+     DataExpression dataExpr = new DataExpression(DataOp.SQL, new HashMap<>(), parseInfo);
+     int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count();
+     int unnamedParamCount = passedParamExprs.size() - namedParamCount;
+
+     if (passedParamExprs.size() < 2){
+         errorListener.validationError(parseInfo, "for sql statement, must specify at least 2 arguments: conn, query");
+         return null;
+     }
+     if (unnamedParamCount > 0){
+         if (namedParamCount > 0) {
+             errorListener.validationError(parseInfo, "for sql statement, cannot mix named and unnamed parameters");
+             return null;
+         }
+         if (unnamedParamCount != 2 && unnamedParamCount != 4) {
+             errorListener.validationError(parseInfo, "for sql statement, at most 4 arguments supported: conn, user, password, query");
+             return null;
+         }
+         // positional forms: (conn, query) or (conn, user, password, query); the query is always last
+         boolean withCredentials = (unnamedParamCount == 4);
+         dataExpr.addSqlExprParam(DataExpression.SQL_CONN, passedParamExprs.get(0).getExpr());
+         if (withCredentials) {
+             dataExpr.addSqlExprParam(DataExpression.SQL_USER, passedParamExprs.get(1).getExpr());
+             dataExpr.addSqlExprParam(DataExpression.SQL_PASS, passedParamExprs.get(2).getExpr());
+         }
+         dataExpr.addSqlExprParam(DataExpression.SQL_QUERY, passedParamExprs.get(withCredentials ? 3 : 1).getExpr());
+     }
+     else {
+         for (ParameterExpression passedParamExpr : passedParamExprs) {
+             dataExpr.addSqlExprParam(passedParamExpr.getName(), passedParamExpr.getExpr());
+         }
+     }
+     dataExpr.setSqlDefault();
+     return dataExpr;
+ }
+
+ /**
+  * Parses federated(addresses, ranges, [type]) into a FEDERATED data expression.
+  * Arguments are either all positional or all named.
+  *
+  * @return the populated expression, or null after a validation error was reported
+  */
+ private static DataExpression processFederatedExpression(String functionName,
+         List<ParameterExpression> passedParamExprs, CustomErrorListener errorListener, ParseInfo parseInfo)
+ {
+     DataExpression dataExpr = new DataExpression(DataOp.FEDERATED, new HashMap<>(), parseInfo);
+     int namedParamCount = (int) passedParamExprs.stream().filter(p -> p.getName()!=null).count();
+     int unnamedParamCount = passedParamExprs.size() - namedParamCount;
+
+     if(passedParamExprs.size() < 2) {
+         errorListener.validationError(parseInfo,
+             "for federated statement, must specify at least 2 arguments: addresses, ranges");
+         return null;
+     }
+
+     if(unnamedParamCount > 0) {
+         if(namedParamCount > 0) {
+             errorListener.validationError(parseInfo,
+                 "for federated statement, cannot mix named and unnamed parameters");
+             return null;
+         }
+         if(unnamedParamCount > 3) {
+             errorListener.validationError(parseInfo,
+                 "for federated statement, at most 3 arguments are supported: addresses, ranges, type");
+             // bail out like every other validation failure instead of returning a half-built expression
+             return null;
+         }
+
+         // positional order: addresses, ranges, [type]; without a type only the first two are set
+         dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, passedParamExprs.get(0).getExpr());
+         dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, passedParamExprs.get(1).getExpr());
+         if(unnamedParamCount == 3)
+             dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, passedParamExprs.get(2).getExpr());
+     }
+     else {
+         for (ParameterExpression passedParamExpr : passedParamExprs) {
+             dataExpr.addFederatedExprParam(passedParamExpr.getName(), passedParamExpr.getExpr());
+         }
+     }
+     dataExpr.setFederatedDefault();
+     return dataExpr;
+ }
public void addRandExprParam(String paramName, Expression paramValue)
{
@@ -544,6 +601,32 @@
paramValue.setParseInfo(this);
addVarParam(paramName,paramValue);
}
+ /**
+  * Registers one named parameter of a frame statement on this expression.
+  * <p>
+  * A validation error is raised if the name is not contained in
+  * {@code FRAME_VALID_PARAM_NAMES} or if a parameter with that name was already added.
+  *
+  * @param paramName  name of the parameter (case-sensitive)
+  * @param paramValue expression assigned to the parameter; its parse info is set to
+  *                   that of this expression before it is stored
+  */
+ public void addFrameExprParam(String paramName, Expression paramValue)
+ {
+     // check name is valid
+     if (!FRAME_VALID_PARAM_NAMES.contains(paramName)) {
+         raiseValidateError("unexpected parameter \"" + paramName +
+             "\". Legal parameters for frame statement are "
+             + "(capitalization-sensitive): " + RAND_DATA + ", " + RAND_ROWS
+             + ", " + RAND_COLS + ", " + SCHEMAPARAM);
+     }
+     // report the duplicated parameter by its name, not by its value expression
+     if (getVarParam(paramName) != null) {
+         raiseValidateError("attempted to add frame statement parameter " + paramName + " more than once");
+     }
+     // TODO convert double Matrix to String Frame
+
+     // add the parameter to expression list
+     paramValue.setParseInfo(this);
+     addVarParam(paramName,paramValue);
+ }
public void addTensorExprParam(String paramName, Expression paramValue)
{
@@ -641,6 +724,13 @@
addVarParam(RAND_BY_ROW, new BooleanIdentifier(true, this));
}
+ /**
+  * Fills in the frame-statement parameters the script left unspecified: the data
+  * parameter becomes a string identifier wrapping {@code null}, and the schema
+  * parameter becomes a string identifier wrapping {@code DEFAULT_SCHEMAPARAM}.
+  * Parameters that are already present are not touched.
+  */
+ public void setFrameDefault() {
+     final boolean dataMissing = (getVarParam(RAND_DATA) == null);
+     if (dataMissing) {
+         addVarParam(RAND_DATA, new StringIdentifier(null, this));
+     }
+     final boolean schemaMissing = (getVarParam(SCHEMAPARAM) == null);
+     if (schemaMissing) {
+         addVarParam(SCHEMAPARAM, new StringIdentifier(DEFAULT_SCHEMAPARAM, this));
+     }
+ }
+
public void setTensorDefault(){
if (getVarParam(RAND_BY_ROW) == null)
addVarParam(RAND_BY_ROW, new BooleanIdentifier(true, this));
@@ -792,7 +882,7 @@
}
inputParamExpr.validateExpression(ids, currConstVars, conditional);
if (s != null && !s.equals(RAND_DATA) && !s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES)
- && !s.equals(DELIM_NA_STRINGS) && getVarParam(s).getOutput().getDataType() != DataType.SCALAR ) {
+ && !s.equals(DELIM_NA_STRINGS) && !s.equals(SCHEMAPARAM) && getVarParam(s).getOutput().getDataType() != DataType.SCALAR ) {
raiseValidateError("Non-scalar data types are not supported for data expression.", conditional,LanguageErrorCodes.INVALID_PARAMETERS);
}
}
@@ -804,20 +894,19 @@
// check if data parameter of matrix is scalar or matrix -- if scalar, use Rand instead
Expression dataParam1 = getVarParam(RAND_DATA);
if (dataParam1 == null && (getOpCode().equals(DataOp.MATRIX) || getOpCode().equals(DataOp.TENSOR))){
- raiseValidateError("for matrix or tensor, must defined data parameter", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+ raiseValidateError("for matrix, frame or tensor, must defined data parameter", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
// We need to remember the operation if we replace the OpCode by rand so we have the correct output
- if (dataParam1 != null && dataParam1.getOutput().getDataType() == DataType.SCALAR &&
+ if (dataParam1!=null && dataParam1.getOutput()!=null && dataParam1.getOutput().getDataType() == DataType.SCALAR &&
(_opcode == DataOp.MATRIX || _opcode == DataOp.TENSOR)/*&& dataParam instanceof ConstIdentifier*/ ){
- //MB: note we should not check for const identifiers here, because otherwise all matrix constructors with
+ //MB: note we must not check for const identifiers here, because otherwise all matrix constructors with
//variable input are routed to a reshape operation (but it works only on matrices and hence, crashes)
// replace DataOp MATRIX with RAND -- Rand handles matrix generation for Scalar values
// replace data parameter with min / max within Rand case below
this.setOpCode(DataOp.RAND);
}
-
-
+
// IMPORTANT: for each operation, one must handle unnamed parameters
switch (this.getOpCode()) {
@@ -1732,7 +1821,7 @@
else {
raiseValidateError("In matrix statement, can only assign rows a long " +
"(integer) value >= 1 -- attempted to assign value: " + colsExpr.toString(), conditional);
- }
+ }
}
else if (colsExpr instanceof DataIdentifier && !(colsExpr instanceof IndexedIdentifier)) {
@@ -1758,7 +1847,6 @@
}
// handle double constant
else if (constValue instanceof DoubleIdentifier){
-
if (((DoubleIdentifier)constValue).getValue() < 1){
raiseValidateError("In matrix statement, can only assign cols a long " +
"(integer) value >= 1 -- attempted to assign value: "
@@ -1768,8 +1856,7 @@
long roundedValue = Double.valueOf(Math.floor(((DoubleIdentifier)constValue).getValue())).longValue();
colsExpr = new IntIdentifier(roundedValue, this);
addVarParam(RAND_COLS, colsExpr);
- colsLong = roundedValue;
-
+ colsLong = roundedValue;
}
else {
// exception -- rows must be integer or double constant
@@ -1781,29 +1868,190 @@
// handle general expression
colsExpr.validateExpression(ids, currConstVars, conditional);
}
-
- }
+ }
else {
// handle general expression
colsExpr.validateExpression(ids, currConstVars, conditional);
}
- }
+ }
getOutput().setFileFormat(FileFormat.BINARY);
getOutput().setDataType(DataType.MATRIX);
getOutput().setValueType(ValueType.FP64);
getOutput().setDimensions(rowsLong, colsLong);
-
+
if (getOutput() instanceof IndexedIdentifier){
((IndexedIdentifier) getOutput()).setOriginalDimensions(getOutput().getDim1(), getOutput().getDim2());
- }
- //getOutput().computeDataType();
-
- if (getOutput() instanceof IndexedIdentifier){
LOG.warn(this.printWarningLocation() + "Output for matrix Statement may have incorrect size information");
}
break;
+ case FRAME:
+ //handle default and input arguments
+ setFrameDefault();
+ validateParams(conditional, FRAME_VALID_PARAM_NAMES,
+ "Legal parameters for frame statement are (case-sensitive): "
+ + RAND_DATA + ", " + RAND_ROWS + ", " + RAND_COLS + ", " + SCHEMAPARAM);
+
+ //validate correct value types
+ if (getVarParam(RAND_ROWS) != null && (getVarParam(RAND_ROWS) instanceof StringIdentifier || getVarParam(RAND_ROWS) instanceof BooleanIdentifier)){
+ raiseValidateError("for frame statement " + RAND_ROWS + " has incorrect value type", conditional);
+ }
+ if (getVarParam(RAND_COLS) != null && (getVarParam(RAND_COLS) instanceof StringIdentifier || getVarParam(RAND_COLS) instanceof BooleanIdentifier)){
+ raiseValidateError("for frame statement " + RAND_COLS + " has incorrect value type", conditional);
+ }
+
+ //validate general data expression
+ getVarParam(RAND_DATA).validateExpression(ids, currConstVars, conditional);
+
+ rowsLong = -1L;
+ colsLong = -1L;
+
+ ///////////////////////////////////////////////////////////////////
+ // HANDLE ROWS
+ ///////////////////////////////////////////////////////////////////
+ rowsExpr = getVarParam(RAND_ROWS);
+ if (rowsExpr != null){
+ if (rowsExpr instanceof IntIdentifier) {
+ if (((IntIdentifier)rowsExpr).getValue() >= 1 )
+ rowsLong = ((IntIdentifier)rowsExpr).getValue();
+ else
+ raiseValidateError("In frame statement, can only assign rows a long " +
+ "(integer) value >= 1 -- attempted to assign value: " + ((IntIdentifier)rowsExpr).getValue(), conditional);
+ }
+ else if (rowsExpr instanceof DoubleIdentifier) {
+ if (((DoubleIdentifier)rowsExpr).getValue() >= 1 )
+ rowsLong = Double.valueOf((Math.floor(((DoubleIdentifier)rowsExpr).getValue()))).longValue();
+ else
+ raiseValidateError("In frame statement, can only assign rows a long " +
+ "(integer) value >= 1 -- attempted to assign value: " + rowsExpr.toString(), conditional);
+ }
+ else if (rowsExpr instanceof DataIdentifier && !(rowsExpr instanceof IndexedIdentifier)) {
+ // check if the DataIdentifier variable is a ConstIdentifier
+ String identifierName = ((DataIdentifier)rowsExpr).getName();
+ if (currConstVars.containsKey(identifierName)){
+ // handle int constant
+ ConstIdentifier constValue = currConstVars.get(identifierName);
+ if (constValue instanceof IntIdentifier){
+ // check rows is >= 1 --- throw exception
+ if (((IntIdentifier)constValue).getValue() < 1){
+ raiseValidateError("In frame statement, can only assign rows a long " +
+ "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional);
+ }
+ // update row expr with new IntIdentifier
+ long roundedValue = ((IntIdentifier)constValue).getValue();
+ rowsExpr = new IntIdentifier(roundedValue, this);
+ addVarParam(RAND_ROWS, rowsExpr);
+ rowsLong = roundedValue;
+ }
+ // handle double constant
+ else if (constValue instanceof DoubleIdentifier){
+ if (((DoubleIdentifier)constValue).getValue() < 1.0){
+ raiseValidateError("In frame statement, can only assign rows a long " +
+ "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional);
+ }
+ // update row expr with new IntIdentifier (rounded down)
+ long roundedValue = Double.valueOf(Math.floor(((DoubleIdentifier)constValue).getValue())).longValue();
+ rowsExpr = new IntIdentifier(roundedValue, this);
+ addVarParam(RAND_ROWS, rowsExpr);
+ rowsLong = roundedValue;
+ }
+ else {
+ // exception -- rows must be integer or double constant
+ raiseValidateError("In frame statement, can only assign rows a long " +
+ "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional);
+ }
+ }
+ else {
+ // handle general expression
+ rowsExpr.validateExpression(ids, currConstVars, conditional);
+ }
+ }
+ else {
+ // handle general expression
+ rowsExpr.validateExpression(ids, currConstVars, conditional);
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////
+ // HANDLE COLUMNS
+ ///////////////////////////////////////////////////////////////////
+
+ // Resolve the requested number of columns. colsLong is only assigned when the value
+ // is a literal or a known constant variable; any other expression is just validated.
+ colsExpr = getVarParam(RAND_COLS);
+ if (colsExpr != null){
+     if (colsExpr instanceof IntIdentifier) {
+         if (((IntIdentifier)colsExpr).getValue() >= 1 )
+             colsLong = ((IntIdentifier)colsExpr).getValue();
+         else
+             raiseValidateError("In frame statement, can only assign cols a long " +
+                 "(integer) value >= 1 -- attempted to assign value: " + colsExpr.toString(), conditional);
+     }
+     else if (colsExpr instanceof DoubleIdentifier) {
+         // double literals are accepted and rounded down
+         if (((DoubleIdentifier)colsExpr).getValue() >= 1 )
+             colsLong = Double.valueOf((Math.floor(((DoubleIdentifier)colsExpr).getValue()))).longValue();
+         else
+             raiseValidateError("In frame statement, can only assign cols a long " +
+                 "(integer) value >= 1 -- attempted to assign value: " + colsExpr.toString(), conditional);
+     }
+     else if (colsExpr instanceof DataIdentifier && !(colsExpr instanceof IndexedIdentifier)) {
+         // check if the DataIdentifier variable is a ConstIdentifier
+         String identifierName = ((DataIdentifier)colsExpr).getName();
+         if (currConstVars.containsKey(identifierName)){
+             // handle int constant
+             ConstIdentifier constValue = currConstVars.get(identifierName);
+             if (constValue instanceof IntIdentifier){
+                 // check cols is >= 1 --- throw exception
+                 if (((IntIdentifier)constValue).getValue() < 1){
+                     raiseValidateError("In frame statement, can only assign cols a long " +
+                         "(integer) value >= 1 -- attempted to assign value: "
+                         + constValue.toString(), conditional);
+                 }
+                 // update col expr with new IntIdentifier
+                 long roundedValue = ((IntIdentifier)constValue).getValue();
+                 colsExpr = new IntIdentifier(roundedValue, this);
+                 addVarParam(RAND_COLS, colsExpr);
+                 colsLong = roundedValue;
+             }
+             // handle double constant
+             else if (constValue instanceof DoubleIdentifier){
+                 if (((DoubleIdentifier)constValue).getValue() < 1){
+                     raiseValidateError("In frame statement, can only assign cols a long " +
+                         "(integer) value >= 1 -- attempted to assign value: "
+                         + constValue.toString(), conditional);
+                 }
+                 // update col expr with new IntIdentifier (rounded down)
+                 long roundedValue = Double.valueOf(Math.floor(((DoubleIdentifier)constValue).getValue())).longValue();
+                 colsExpr = new IntIdentifier(roundedValue, this);
+                 addVarParam(RAND_COLS, colsExpr);
+                 colsLong = roundedValue;
+             }
+             else {
+                 // exception -- cols must be integer or double constant
+                 raiseValidateError("In frame statement, can only assign cols a long " +
+                     "(integer) value >= 1 -- attempted to assign value: " + constValue.toString(), conditional);
+             }
+         }
+         else {
+             // handle general expression
+             colsExpr.validateExpression(ids, currConstVars, conditional);
+         }
+     }
+     else {
+         // handle general expression
+         colsExpr.validateExpression(ids, currConstVars, conditional);
+     }
+ }
+ getOutput().setFileFormat(FileFormat.BINARY);
+ getOutput().setDataType(DataType.FRAME);
+ getOutput().setValueType(ValueType.UNKNOWN);
+ getOutput().setDimensions(rowsLong, colsLong);
+
+ if (getOutput() instanceof IndexedIdentifier){
+ ((IndexedIdentifier) getOutput()).setOriginalDimensions(getOutput().getDim1(), getOutput().getDim2());
+ LOG.warn(this.printWarningLocation() + "Output for frame Statement may have incorrect size information");
+ }
+ break;
+
case TENSOR:
//handle default and input arguments
setTensorDefault();
@@ -1840,9 +2088,8 @@
if (getOutput() instanceof IndexedIdentifier){
LOG.warn(this.printWarningLocation() + "Output for tensor Statement may have incorrect size information");
}
-
break;
-
+
case SQL:
//handle default and input arguments
setSqlDefault();
diff --git a/src/main/java/org/apache/sysds/parser/Expression.java b/src/main/java/org/apache/sysds/parser/Expression.java
index 059d093..e7b49ca 100644
--- a/src/main/java/org/apache/sysds/parser/Expression.java
+++ b/src/main/java/org/apache/sysds/parser/Expression.java
@@ -60,7 +60,7 @@
* Data operators.
*/
public enum DataOp {
- READ, WRITE, RAND, MATRIX, TENSOR, SQL, FEDERATED
+ READ, WRITE, RAND, MATRIX, FRAME, TENSOR, SQL, FEDERATED
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index da86189..0a7e28d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -285,6 +285,7 @@
String2CPInstructionType.put( DataGen.SINIT_OPCODE , CPType.StringInit);
String2CPInstructionType.put( DataGen.SAMPLE_OPCODE , CPType.Rand);
String2CPInstructionType.put( DataGen.TIME_OPCODE , CPType.Rand);
+ String2CPInstructionType.put( DataGen.FRAME_OPCODE , CPType.Rand);
String2CPInstructionType.put( "ctable", CPType.Ctable);
String2CPInstructionType.put( "ctableexpand", CPType.Ctable);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index b8fdfe8..4b77bff 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -284,7 +284,8 @@
String2SPInstructionType.put( DataGen.RAND_OPCODE , SPType.Rand);
String2SPInstructionType.put( DataGen.SEQ_OPCODE , SPType.Rand);
String2SPInstructionType.put( DataGen.SAMPLE_OPCODE, SPType.Rand);
-
+ String2SPInstructionType.put( DataGen.FRAME_OPCODE, SPType.Rand);
+
//ternary instruction opcodes
String2SPInstructionType.put( "ctable", SPType.Ctable);
String2SPInstructionType.put( "ctableexpand", SPType.Ctable);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index 86b95a3..553577e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -29,12 +29,14 @@
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator;
@@ -42,6 +44,8 @@
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
+import java.util.Arrays;
+import java.util.Random;
public class DataGenCPInstruction extends UnaryCPInstruction {
private static final Log LOG = LogFactory.getLog(DataGenCPInstruction.class.getName());
@@ -52,7 +56,7 @@
private boolean minMaxAreDoubles;
private final String minValueStr, maxValueStr;
private final double minValue, maxValue, sparsity;
- private final String pdf, pdfParams;
+ private final String pdf, pdfParams, frame_data, schema;
private final long seed;
private Long runtimeSeed;
@@ -67,10 +71,11 @@
private static final int SEED_POSITION_RAND = 8;
private static final int SEED_POSITION_SAMPLE = 4;
- private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out,
- CPOperand rows, CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, double sparsity, long seed,
- String probabilityDensityFunction, String pdfParams, int k,
- CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, boolean replace, String opcode, String istr) {
+ private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out,
+ CPOperand rows, CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, double sparsity,
+ long seed, String probabilityDensityFunction, String pdfParams, int k, CPOperand seqFrom, CPOperand seqTo,
+ CPOperand seqIncr, boolean replace, String data, String schema, String opcode, String istr)
+ {
super(CPType.Rand, op, in, out, opcode, istr);
this.method = mthd;
this.rows = rows;
@@ -107,29 +112,38 @@
this.seq_to = seqTo;
this.seq_incr = seqIncr;
this.replace = replace;
+ this.frame_data = data;
+ this.schema = schema;
}
private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
CPOperand dims, int blen, String minValue, String maxValue, double sparsity, long seed,
- String probabilityDensityFunction, String pdfParams, int k, String opcode, String istr) {
+ String probabilityDensityFunction, String pdfParams, int k, String opcode, String istr) {
this(op, mthd, in, out, rows, cols, dims, blen, minValue, maxValue, sparsity, seed,
- probabilityDensityFunction, pdfParams, k, null, null, null, false, opcode, istr);
+ probabilityDensityFunction, pdfParams, k, null, null, null,
+ false, null, null, opcode, istr);
}
private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
CPOperand dims, int blen, String maxValue, boolean replace, long seed, String opcode, String istr) {
this(op, mthd, in, out, rows, cols, dims, blen, "0", maxValue, 1.0, seed,
- null, null, 1, null, null, null, replace, opcode, istr);
+ null, null, 1, null, null, null, replace, null, null, opcode, istr);
}
private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
CPOperand dims, int blen, CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) {
this(op, mthd, in, out, rows, cols, dims, blen, "0", "1", 1.0, -1,
- null, null, 1, seqFrom, seqTo, seqIncr, false, opcode, istr);
+ null, null, 1, seqFrom, seqTo, seqIncr, false, null, null, opcode, istr);
}
private DataGenCPInstruction(Operator op, OpOpDG mthd, CPOperand out, String opcode, String istr) {
this(op, mthd, null, out, null, null, null, 0, "0", "0", 0, 0,
- null, null, 1, null, null, null, false, opcode, istr);
+ null, null, 1, null, null, null, false, null, null, opcode, istr);
+ }
+
+ /**
+  * Creates a data-generation instruction that carries frame data and a schema
+  * (used by parseInstruction for the FRAMEINIT method).
+  * Delegates to the full constructor with no input operand, no dims operand,
+  * block length 0, min/max "0", sparsity 0, seed 0, no pdf, k=1, no sequence
+  * operands and replace=false.
+  *
+  * @param op     operator passed through to the full constructor
+  * @param method data-generation method
+  * @param out    output operand
+  * @param rows   operand holding the number of rows
+  * @param cols   operand holding the number of columns
+  * @param data   frame data string, stored in frame_data
+  * @param schema schema string, stored in schema
+  * @param opcode instruction opcode
+  * @param str    full instruction string
+  */
+ public DataGenCPInstruction(Operator op, OpOpDG method, CPOperand out, CPOperand rows, CPOperand cols, String data,
+ String schema, String opcode, String str) {
+ this(op, method, null, out, rows, cols, null, 0, "0", "0", 0, 0,
+ null, null, 1, null, null, null, false, data, schema, opcode, str);
}
public long getRows() {
@@ -217,6 +231,10 @@
// 1 operand: outvar
InstructionUtils.checkNumFields ( s, 1 );
}
+ else if ( opcode.equalsIgnoreCase(DataGen.FRAME_OPCODE) ) {
+ method = OpOpDG.FRAMEINIT;
+ InstructionUtils.checkNumFields ( s, 5 );
+ }
CPOperand out = new CPOperand(s[s.length-1]);
Operator op = null;
@@ -247,15 +265,23 @@
return new DataGenCPInstruction(op, method, null, out, rows, cols, dims, blen,
s[5 - missing], s[6 - missing], sparsity, seed, pdf, pdfParams, k, opcode, str);
}
- else if ( method == OpOpDG.SEQ)
+ else if ( method == OpOpDG.SEQ)
{
int blen = Integer.parseInt(s[3]);
CPOperand from = new CPOperand(s[4]);
CPOperand to = new CPOperand(s[5]);
CPOperand incr = new CPOperand(s[6]);
-
+
return new DataGenCPInstruction(op, method, null, out, null, null, null, blen, from, to, incr, opcode, str);
}
+ else if ( method == OpOpDG.FRAMEINIT)
+ {
+ String data = s[1];
+ CPOperand rows = new CPOperand(s[2]);
+ CPOperand cols = new CPOperand(s[3]);
+ String valueType = s[4];
+ return new DataGenCPInstruction(op, method, out, rows, cols, data, valueType, opcode, str);
+ }
else if ( method == OpOpDG.SAMPLE)
{
CPOperand rows = new CPOperand(s[2]);
@@ -282,6 +308,7 @@
MatrixBlock soresBlock = null;
TensorBlock tensorBlock = null;
ScalarObject soresScalar = null;
+ FrameBlock soresFrame = null;
//process specific datagen operator
if ( method == OpOpDG.RAND ) {
@@ -369,7 +396,41 @@
else if ( method == OpOpDG.TIME ) {
soresScalar = new IntObject(System.nanoTime());
}
-
+ else if(method == OpOpDG.FRAMEINIT)
+ {
+ int lrows = (int) ec.getScalarInput(rows).getLongValue();
+ int lcols = (int) ec.getScalarInput(cols).getLongValue();
+ String schemaValues[] = schema.split(DataExpression.DELIM_NA_STRING_SEP);
+ ValueType[] vt = schemaValues[0].equals(DataExpression.DEFAULT_SCHEMAPARAM) ?
+ UtilFunctions.nCopies(lcols, ValueType.STRING) :
+ UtilFunctions.stringToValueType(schemaValues);
+ int schemaLength = vt.length;
+ if(schemaLength != lcols)
+ throw new DMLRuntimeException("schema-dimension mismatch");
+
+ if(frame_data.equals("")) {
+ //TODO fix hard-coded seed, consistently with sparse frame init
+ soresFrame = UtilFunctions.generateRandomFrameBlock(lrows, lcols, vt, new Random(10));
+ }
+ else {
+ String[] data = frame_data.split(DataExpression.DELIM_NA_STRING_SEP);
+ if(data.length != schemaLength && data.length > 1)
+ throw new DMLRuntimeException("data values should be equal to number of columns," +
+ " or a single values for all columns");
+ if(data.length > 1) {
+ soresFrame = new FrameBlock(vt);
+ for(int i = 0; i < lrows; i++)
+ soresFrame.appendRow(data);
+ }
+ else {
+ soresFrame = new FrameBlock(vt);
+ String[] data1 = new String[lcols];
+ Arrays.fill(data1, frame_data);
+ for(int i = 0; i < lrows; i++)
+ soresFrame.appendRow(data1);
+ }
+ }
+ }
if( output.isMatrix() ) {
//guarded sparse block representation change
if( soresBlock.getInMemorySize() < OptimizerUtils.SAFE_REP_CHANGE_THRES )
@@ -386,6 +447,8 @@
}
else if( output.isScalar() )
ec.setScalarOutput(output.getName(), soresScalar);
+ else if (output.isFrame())
+ ec.setFrameOutput(output.getName(), soresFrame);
}
private static void checkValidDimensions(long rows, long cols) {
@@ -444,4 +507,5 @@
new CPOperand(ec.getScalarInput(op)).getLineageLiteral());
return inst;
}
+
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
index dac0aa5..3d0c2e5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
@@ -50,6 +50,7 @@
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -62,6 +63,7 @@
import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixCell;
@@ -90,7 +92,7 @@
private final double minValue, maxValue;
private final String minValueStr, maxValueStr;
private final double sparsity;
- private final String pdf, pdfParams;
+ private final String pdf, pdfParams, frame_data, schema;
private long seed = 0;
private final String dir;
private final CPOperand seq_from, seq_to, seq_incr;
@@ -104,9 +106,10 @@
private static final int SEED_POSITION_SAMPLE = 4;
private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows,
- CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue,
- double sparsity, long seed, String dir, String probabilityDensityFunction, String pdfParams,
- CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, boolean replace, String opcode, String istr)
+ CPOperand cols, CPOperand dims, int blen, String minValue, String maxValue, double sparsity,
+ long seed, String dir, String probabilityDensityFunction, String pdfParams, CPOperand seqFrom,
+ CPOperand seqTo, CPOperand seqIncr, boolean replace, String fdata,
+ String schema, String opcode, String istr)
{
super(SPType.Rand, op, in, out, opcode, istr);
this._method = mthd;
@@ -131,7 +134,6 @@
}
minDouble = -1;
maxDouble = -1;
- //minMaxAreDoubles = false;
}
this.minValue = minDouble;
this.maxValue = maxDouble;
@@ -144,6 +146,8 @@
this.seq_to = seqTo;
this.seq_incr = seqIncr;
this.replace = replace;
+ this.frame_data = fdata;
+ this.schema = schema;
}
private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows,
@@ -151,20 +155,30 @@
String dir, String probabilityDensityFunction, String pdfParams, String opcode, String istr)
{
this(op, mthd, in, out, rows, cols, dims, blen, minValue, maxValue, sparsity, seed, dir,
- probabilityDensityFunction, pdfParams, null, null, null, false, opcode, istr);
+ probabilityDensityFunction, pdfParams, null, null,
+ null, false, null, null, opcode, istr);
}
private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows,
CPOperand cols, CPOperand dims, int blen, CPOperand seqFrom, CPOperand seqTo,
CPOperand seqIncr, String opcode, String istr) {
this(op, mthd, in, out, rows, cols, dims, blen, "-1", "-1", -1, -1, null,
- null, null, seqFrom, seqTo, seqIncr, false, opcode, istr);
+ null, null, seqFrom, seqTo, seqIncr, false,
+ null, null, opcode, istr);
}
private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
CPOperand dims, int blen, String maxValue, boolean replace, long seed, String opcode, String istr) {
this(op, mthd, in, out, rows, cols, dims, blen, "-1", maxValue, -1, seed, null,
- null, null, null, null, null, replace, opcode, istr);
+ null, null, null, null, null, replace,
+ null, null, opcode, istr);
+ }
+
+ /**
+  * Creates a Spark data-generation instruction that carries frame data and a schema
+  * (used by parseInstruction for the FRAMEINIT method).
+  * Delegates to the full constructor with no input operand, no dims operand,
+  * block length 0, min "0", max "1", sparsity 0, seed 0, no dir, no pdf, no
+  * sequence operands and replace=false.
+  * NOTE(review): dir is passed as null here, while generateFrame hands dir to
+  * LibMatrixDatagen.generateUniqueSeedPath on its file-based path -- confirm that
+  * a null dir is acceptable there.
+  *
+  * @param op     operator passed through to the full constructor
+  * @param mthd   data-generation method
+  * @param out    output operand
+  * @param rows   operand holding the number of rows
+  * @param cols   operand holding the number of columns
+  * @param fdata  frame data string, stored in frame_data
+  * @param schema schema string, stored in schema
+  * @param opcode instruction opcode
+  * @param istr   full instruction string
+  */
+ private RandSPInstruction(Operator op, OpOpDG mthd, CPOperand out, CPOperand rows,
+ CPOperand cols, String fdata, String schema, String opcode, String istr) {
+ this(op, mthd, null, out, rows, cols, null, 0, "0", "1", 0,
+ 0, null,null, null, null, null,
+ null, false,fdata, schema, opcode, istr);
}
public long getRows() {
@@ -224,6 +238,11 @@
// 7 operands: range, size, replace, seed, blen, outvar
InstructionUtils.checkNumFields ( str, 6 );
}
+ else if ( opcode.equalsIgnoreCase(DataGen.FRAME_OPCODE) ) {
+ method = OpOpDG.FRAMEINIT;
+ InstructionUtils.checkNumFields ( str, 6 );
+ }
+
Operator op = null;
// output is specified by the last operand
@@ -264,8 +283,7 @@
return new RandSPInstruction(op, method, in, out, null,
null, null, blen, from, to, incr, opcode, str);
}
- else if ( method == OpOpDG.SAMPLE)
- {
+ else if ( method == OpOpDG.SAMPLE) {
String max = !s[1].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ?
s[1] : "0";
CPOperand rows = new CPOperand(s[2]);
@@ -279,6 +297,14 @@
return new RandSPInstruction(op, method, null, out, rows, cols,
null, blen, max, replace, seed, opcode, str);
}
+ else if ( method == OpOpDG.FRAMEINIT) {
+ String data = s[1];
+ CPOperand rows = new CPOperand(s[2]);
+ CPOperand cols = new CPOperand(s[3]);
+ String valueType = s[4];
+ return new RandSPInstruction(op, method, out, rows, cols, data, valueType, opcode, str);
+ }
+
else
throw new DMLRuntimeException("Unrecognized data generation method: " + method);
}
@@ -292,17 +318,101 @@
case RAND: generateRandData(sec); break;
case SEQ: generateSequence(sec); break;
case SAMPLE: generateSample(sec); break;
+ case FRAMEINIT: generateFrame(sec); break;
default:
throw new DMLRuntimeException("Invalid datagen method: "+_method);
}
}
- private void generateRandData(SparkExecutionContext sec) {
- if (output.getDataType() == DataType.MATRIX) {
- generateRandDataMatrix(sec);
- } else {
- generateRandDataTensor(sec);
+ /**
+  * Generates a frame on Spark for the FRAMEINIT method.
+  * Builds an RDD with one (key, seed) pair per row block, maps every pair to a
+  * FrameBlock via GenerateRandomFrameBlock, and registers the resulting RDD as
+  * the handle of the output variable.
+  *
+  * @param sec Spark execution context providing the scalar inputs, the Spark
+  *            context and the output variable
+  * @throws DMLRuntimeException if the seed file cannot be written or if the
+  *         schema length differs from the number of columns
+  */
+ private void generateFrame(SparkExecutionContext sec) {
+ long lrows = sec.getScalarInput(rows).getLongValue();
+ long lcols = sec.getScalarInput(cols).getLongValue();
+ String data = frame_data;
+
+ //step 1: generate pseudo-random seed (because not specified)
+ long lSeed = generateRandomSeed();
+
+ if( LOG.isTraceEnabled() )
+ LOG.trace("Process RandSPInstruction frame with seed = "+lSeed+".");
+
+ //step 2: seed generation
+ JavaPairRDD<Long, Long> seedsRDD = null;
+ Well1024a bigrand = LibMatrixDatagen.setupSeedsForRand(lSeed);
+ double totalSize = OptimizerUtils.estimatePartitionedSizeExactSparsity( lrows, lcols, -1, 1);
+ double hdfsBlkSize = InfrastructureAnalyzer.getHDFSBlockSize();
+ int brlen = ConfigurationManager.getBlocksize();
+ //only used to derive the number of row blocks
+ DataCharacteristics tmp = new MatrixCharacteristics(lrows, lcols, brlen);
+
+ //a) in-memory seed rdd construction
+ if( tmp.getNumRowBlocks() < INMEMORY_NUMBLOCKS_THRESHOLD )
+ {
+ ArrayList<Tuple2<Long, Long>> seeds = new ArrayList<>();
+ for( long i=0; i<tmp.getNumRowBlocks(); i++ ) {
+ Long seedForBlock = bigrand.nextLong();
+ //key is i*brlen+1 (1 for the first row block); presumably the 1-based index of the block's first row
+ seeds.add(new Tuple2<>(i*brlen+1, seedForBlock));
+ }
+
+ //for load balancing: degree of parallelism such that ~128MB per partition
+ int numPartitions = (int) Math.max(Math.min(totalSize/hdfsBlkSize, tmp.getNumRowBlocks()), 1);
+
+ //create seeds rdd
+ seedsRDD = sec.getSparkContext().parallelizePairs(seeds, numPartitions);
}
+ //b) file-based seed rdd construction (for robustness wrt large number of blocks)
+ else
+ {
+ Path path = new Path(LibMatrixDatagen.generateUniqueSeedPath(dir));
+ PrintWriter pw = null;
+ try
+ {
+ FileSystem fs = IOUtilFunctions.getFileSystem(path);
+ pw = new PrintWriter(fs.create(path));
+ StringBuilder sb = new StringBuilder();
+ //one "key,seed" line per row block; parsed back by ExtractFrameSeedTuple below
+ for( long i=0; i<tmp.getNumRowBlocks(); i++ ) {
+ sb.append(i*brlen+1);
+ sb.append(',');
+ sb.append(bigrand.nextLong());
+ pw.println(sb.toString());
+ sb.setLength(0);
+ }
+ }
+ catch( IOException ex ) {
+ throw new DMLRuntimeException(ex);
+ }
+ finally {
+ IOUtilFunctions.closeSilently(pw);
+ }
+
+ //for load balancing: degree of parallelism such that ~128MB per partition
+ int numPartitions = (int) Math.max(Math.min(totalSize/hdfsBlkSize, tmp.getNumRowBlocks()), 1);
+
+ //create seeds rdd
+ seedsRDD = sec.getSparkContext()
+ .textFile(path.toString(), numPartitions)
+ .mapToPair(new ExtractFrameSeedTuple());
+ }
+
+ //step 3: prepare input arguments (default schema marker expands to all-STRING columns)
+ String schemaValues[] = schema.split(DataExpression.DELIM_NA_STRING_SEP);
+ ValueType[] vt = (schemaValues[0].equals(DataExpression.DEFAULT_SCHEMAPARAM)) ?
+ UtilFunctions.nCopies((int)lcols, ValueType.STRING) :
+ UtilFunctions.stringToValueType(schemaValues);
+ if(vt.length != lcols)
+ throw new DMLRuntimeException("schema-dimension mismatch: "+vt.length+" vs "+lcols);
+
+ //step 4: execute rand instruction over seed input
+ //NOTE(review): the visible part of GenerateRandomFrameBlock has its seed read commented out -- confirm the per-block seeds are intended to be unused
+ JavaPairRDD<Long, FrameBlock> out = seedsRDD
+ .mapToPair(new GenerateRandomFrameBlock(lrows, lcols, brlen, vt, data));
+
+ //step 5: output handling
+ sec.setRDDHandleForVariable(output.getName(), out);
+ }
+
+ /**
+  * Generates random data for the RAND method: dispatches to the matrix generator
+  * when the output data type is MATRIX and to the tensor generator otherwise,
+  * then clears the runtime seed.
+  *
+  * @param sec Spark execution context passed on to the selected generator
+  */
+ private void generateRandData(SparkExecutionContext sec) {
+ if (output.getDataType() == DataType.MATRIX)
+ generateRandDataMatrix(sec);
+ else
+ generateRandDataTensor(sec);
//reset runtime seed (e.g., when executed in loop)
runtimeSeed = null;
}
@@ -800,6 +910,18 @@
}
}
+ /**
+  * Parses one comma-separated text line into a pair of longs: the first field
+  * becomes the key (row index), the second field the value (seed).
+  */
+ private static class ExtractFrameSeedTuple implements PairFunction<String, Long, Long> {
+     private static final long serialVersionUID = 3973794676854157100L;
+
+     @Override
+     public Tuple2<Long, Long> call(String arg)
+         throws Exception
+     {
+         final String[] fields = IOUtilFunctions.split(arg, ",");
+         //parse key first, then value (same evaluation order as field order)
+         final Long rowIndex = Long.parseLong(fields[0]);
+         final long seed = Long.parseLong(fields[1]);
+         return new Tuple2<>(rowIndex, seed);
+     }
+ }
private static class ExtractMatrixSeedTuple implements PairFunction<String, MatrixIndexes, Long> {
private static final long serialVersionUID = 3973794676854157101L;
@@ -836,7 +958,62 @@
return Double.parseDouble(arg);
}
}
+ /**
+  * Spark pair function that turns a (row index, seed) tuple into one frame block.
+  * With an empty data string the block is filled with random values, otherwise
+  * with a constant row built from the delimiter-separated data string.
+  */
+ private static class GenerateRandomFrameBlock implements PairFunction<Tuple2<Long, Long>, Long, FrameBlock>
+ {
+     private static final long serialVersionUID = 1616346120426470173L;
+     private final long _rlen;
+     private final long _clen;
+     private final int _brlen;
+     private final ValueType[] _schema;
+     private final String _data;
+
+     /**
+      * @param rlen   total number of rows of the frame
+      * @param clen   number of columns of the frame
+      * @param brlen  number of rows per block
+      * @param schema value type per column
+      * @param fdata  empty string for random data, otherwise the delimiter-separated cell values
+      */
+     public GenerateRandomFrameBlock(long rlen, long clen, int brlen, ValueType[] schema, String fdata) {
+         _rlen = rlen;
+         _clen = clen;
+         _brlen = brlen;
+         _schema = schema;
+         _data = fdata;
+     }
+
+     /**
+      * Builds the frame block for the given seed tuple.
+      *
+      * @param kv tuple of row index (key) and seed (value)
+      * @return tuple of the unchanged row index and the generated block
+      * @throws DMLRuntimeException if more than one data value is given and the count
+      *         differs from the number of schema columns
+      */
+     @Override
+     public Tuple2<Long, FrameBlock> call(Tuple2<Long, Long> kv)
+         throws Exception
+     {
+         //compute local block size:
+         Long ix = kv._1();
+         long blockix = UtilFunctions.computeBlockIndex(ix, _brlen);
+         int lrlen = UtilFunctions.computeBlockSize(_rlen, blockix, _brlen);
+
+         FrameBlock out = null;
+         if(_data.equals("")) {
+             //generate only the rows of this block; using the total row count here
+             //would give every block of a multi-block frame the full frame height
+             //TODO fix hard-coded seed (the per-block seed kv._2 is not used yet, so
+             //all blocks currently draw the same random sequence)
+             out = UtilFunctions.generateRandomFrameBlock(lrlen, (int)_clen, _schema, new Random(10));
+         }
+         else {
+             String[] data = _data.split(DataExpression.DELIM_NA_STRING_SEP);
+             if(data.length != _schema.length && data.length > 1)
+                 throw new DMLRuntimeException("data values should be equal "
+                     + "to number of columns, or a single values for all columns");
+
+             //one value per column is used as is; otherwise the raw data string
+             //is replicated across all columns
+             String[] row = data;
+             if(data.length <= 1) {
+                 row = new String[(int)_clen];
+                 Arrays.fill(row, _data);
+             }
+
+             out = new FrameBlock(_schema);
+             for(int i = 0; i < lrlen; i++)
+                 out.appendRow(row);
+         }
+
+         return new Tuple2<>(kv._1, out);
+     }
+ }
+
private static class GenerateRandomBlock implements PairFunction<Tuple2<MatrixIndexes, Long>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = 1616346120426470173L;
@@ -911,7 +1088,7 @@
@Override
public Tuple2<TensorIndexes, TensorBlock> call(Tuple2<TensorIndexes, Long> kv)
- throws Exception
+ throws Exception
{
//compute local block size:
TensorIndexes ix = kv._1();
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index f98cceb..a7fdaf4 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -105,7 +105,7 @@
public static int nextIntPow2( int in ) {
int expon = (in==0) ? 0 : 32-Integer.numberOfLeadingZeros(in-1);
long pow2 = pow(2, expon);
- return (int)((pow2>Integer.MAX_VALUE)?Integer.MAX_VALUE : pow2);
+ return (int)((pow2>Integer.MAX_VALUE)?Integer.MAX_VALUE : pow2);
}
public static long pow(int base, int exp) {
@@ -835,5 +835,74 @@
.map(DATE_FORMATS::get).orElseThrow(() -> new NullPointerException("Unknown date format."));
}
+ /**
+  * Creates a frame block of the requested size and fills each cell with a random
+  * value of its column's value type. Column names are the string form of the
+  * corresponding schema entries.
+  *
+  * @param rows   number of rows to generate
+  * @param cols   number of columns to generate
+  * @param schema value type per column
+  * @param random random number generator used for all cells
+  * @return the populated FrameBlock
+  */
+ public static FrameBlock generateRandomFrameBlock(int rows, int cols, ValueType[] schema, Random random){
+     final String[] colNames = new String[cols];
+     int c = 0;
+     while(c < cols) {
+         colNames[c] = schema[c].toString();
+         c++;
+     }
+
+     final FrameBlock fb = new FrameBlock(schema, colNames);
+     fb.ensureAllocatedColumns(rows);
+
+     //row-major fill: the order of draws from the generator defines the content
+     for(int r = 0; r < rows; r++) {
+         for(int j = 0; j < cols; j++) {
+             Object cell = generateRandomValueFromValueType(schema[j], random);
+             fb.set(r, j, cell);
+         }
+     }
+     return fb;
+ }
+ /**
+ * Generates a random value for a given Value Type
+ *
+ * @param valueType the ValueType of which to generate the value
+ * @param random random number generator
+ * @return Object
+ */
+ public static Object generateRandomValueFromValueType(ValueType valueType, Random random){
+ switch (valueType){
+ case FP32: return random.nextFloat();
+ case FP64: return random.nextDouble();
+ case INT32: return random.nextInt();
+ case INT64: return random.nextLong();
+ case BOOLEAN: return random.nextBoolean();
+ case STRING:
+ return random.ints('a', 'z' + 1).limit(10)
+ .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append)
+ .toString();
+ default:
+ return null;
+ }
+ }
+
+ /**
+  * Generates a ValueType array from a String array. Type names are matched
+  * case-insensitively against STRING, FP64, FP32, INT64, INT32 and BOOLEAN.
+  *
+  * @param schemaValues the string schema of which to generate the ValueType
+  * @return ValueType[] with one entry per input string, in the same order
+  * @throws DMLRuntimeException if an entry is null or not one of the supported type names
+  */
+ public static ValueType[] stringToValueType(String[] schemaValues) {
+     ValueType[] vt = new ValueType[schemaValues.length];
+     for(int i=0; i < schemaValues.length; i++) {
+         final String name = schemaValues[i];
+         //constant-first comparisons are null-safe, so a null entry ends up in the
+         //descriptive exception below instead of a bare NullPointerException
+         if("STRING".equalsIgnoreCase(name))
+             vt[i] = ValueType.STRING;
+         else if ("FP64".equalsIgnoreCase(name))
+             vt[i] = ValueType.FP64;
+         else if ("FP32".equalsIgnoreCase(name))
+             vt[i] = ValueType.FP32;
+         else if ("INT64".equalsIgnoreCase(name))
+             vt[i] = ValueType.INT64;
+         else if ("INT32".equalsIgnoreCase(name))
+             vt[i] = ValueType.INT32;
+         else if ("BOOLEAN".equalsIgnoreCase(name))
+             vt[i] = ValueType.BOOLEAN;
+         else
+             //original message kept as prefix; offending value and position appended
+             throw new DMLRuntimeException("Invalid column schema. Allowed values are STRING, FP64, FP32, INT64, INT32 and Boolean"
+                 + " (found '" + name + "' at position " + (i+1) + ")");
+     }
+     return vt;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java
index 4c4dc20..d4ddfe4 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java
@@ -38,19 +38,24 @@
private final static double eps = 1e-3;
private final static int rows = 1700;
- //private final static double spDense = 0.99;
private final static double epsDBSCAN = 1;
private final static int minPts = 5;
@Override
- public void setUp() { addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"})); }
+ public void setUp() {
+ addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+ }
@Test
- public void testDBSCANDefaultCP() { runDBSCAN(true, ExecType.CP); }
+ public void testDBSCANDefaultCP() {
+ runDBSCAN(true, ExecType.CP);
+ }
@Test
- public void testDBSCANDefaultSP() { runDBSCAN(true, ExecType.SPARK); }
+ public void testDBSCANDefaultSP() {
+ runDBSCAN(true, ExecType.SPARK);
+ }
private void runDBSCAN(boolean defaultProb, ExecType instType)
{
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java
new file mode 100644
index 0000000..ef9e8a6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Random;
+
+/**
+ * Tests the DML frame constructor for fully named parameters, missing schema,
+ * random data and a single constant value, each in single-node and Spark mode.
+ */
+public class FrameConstructorTest extends AutomatedTestBase {
+    private final static String TEST_DIR = "functions/frame/";
+    private final static String TEST_NAME = "FrameConstructorTest";
+    private final static String TEST_CLASS_DIR = TEST_DIR + FrameConstructorTest.class.getSimpleName() + "/";
+
+    //dimensions must match the rows/cols literals used in the DML script
+    private final static int rows = 40;
+    private final static int cols = 4;
+
+    private final static ValueType[] schemaStrings1 = new ValueType[]{
+        ValueType.INT64, ValueType.STRING, ValueType.FP64, ValueType.BOOLEAN};
+
+    private final static ValueType[] schemaStrings2 = new ValueType[]{
+        ValueType.INT64, ValueType.STRING, ValueType.FP64, ValueType.STRING};
+
+    /** Test cases; the enum name is passed to the DML script as first argument. */
+    private enum TestType {
+        NAMED,
+        NO_SCHEMA,
+        RANDOM_DATA,
+        SINGLE_DATA
+    }
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"}));
+        if (TEST_CACHE_ENABLED) {
+            setOutAndExpectedDeletionDisabled(true);
+        }
+    }
+
+    @Test
+    public void testFrameNamedParam() {
+        FrameBlock exp = createExpectedFrame(schemaStrings1, false);
+        runFrameTest(TestType.NAMED, exp, Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    public void testFrameNamedParamSP() {
+        FrameBlock exp = createExpectedFrame(schemaStrings1, false);
+        runFrameTest(TestType.NAMED, exp, Types.ExecMode.SPARK);
+    }
+
+    @Test
+    public void testNoSchema() {
+        FrameBlock exp = createExpectedFrame(schemaStrings2, false);
+        runFrameTest(TestType.NO_SCHEMA, exp, Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    public void testNoSchemaSP() {
+        FrameBlock exp = createExpectedFrame(schemaStrings2, false);
+        runFrameTest(TestType.NO_SCHEMA, exp, Types.ExecMode.SPARK);
+    }
+
+    @Test
+    public void testRandData() {
+        //same generator seed as used by the frame data generator
+        FrameBlock exp = UtilFunctions.generateRandomFrameBlock(rows, cols, schemaStrings1, new Random(10));
+        runFrameTest(TestType.RANDOM_DATA, exp, Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    public void testRandDataSP() {
+        FrameBlock exp = UtilFunctions.generateRandomFrameBlock(rows, cols, schemaStrings1, new Random(10));
+        runFrameTest(TestType.RANDOM_DATA, exp, Types.ExecMode.SPARK);
+    }
+
+    @Test
+    public void testSingleData() {
+        FrameBlock exp = createExpectedFrame(schemaStrings1, true);
+        runFrameTest(TestType.SINGLE_DATA, exp, Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    public void testSingleDataSP() {
+        FrameBlock exp = createExpectedFrame(schemaStrings1, true);
+        runFrameTest(TestType.SINGLE_DATA, exp, Types.ExecMode.SPARK);
+    }
+
+    /**
+     * Runs the DML script for the given test type and compares the written CSV
+     * frame against the expected frame.
+     *
+     * @param type           test case selector passed to the script
+     * @param expectedOutput frame the script output is compared with
+     * @param et             execution mode to run the script in
+     */
+    private void runFrameTest(TestType type, FrameBlock expectedOutput, Types.ExecMode et) {
+        //snapshot global flags BEFORE switching the exec mode, so the values restored
+        //in the finally block are the pre-test ones even if setExecMode changes them
+        boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+        boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+        Types.ExecMode platformOld = setExecMode(et);
+
+        try {
+            //setup testcase
+            getAndLoadTestConfiguration(TEST_NAME);
+            String HOME = SCRIPT_DIR + TEST_DIR;
+            fullDMLScriptName = HOME + TEST_NAME + ".dml";
+            programArgs = new String[] {"-explain", "-args", String.valueOf(type), output("F2")};
+
+            runTest(true, false, null, -1);
+            FrameBlock fB = FrameReaderFactory
+                .createFrameReader(Types.FileFormat.CSV)
+                .readFrameFromHDFS(output("F2"), rows, cols);
+            String[][] R1 = DataConverter.convertToStringFrame(expectedOutput);
+            String[][] R2 = DataConverter.convertToStringFrame(fB);
+            TestUtils.compareFrames(R1, R2, R1.length, R1[0].length);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+        }
+        finally {
+            //restore only the state captured above; other optimizer flags are never
+            //modified by this test and therefore must not be overwritten here
+            rtplatform = platformOld;
+            DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+            OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+        }
+    }
+
+    /**
+     * Builds the expected frame of {@code rows} identical rows.
+     *
+     * @param schema   schema of the expected frame
+     * @param constant true for the all-"1" row, false for the mixed-value row
+     * @return expected frame block
+     */
+    private static FrameBlock createExpectedFrame(ValueType[] schema, boolean constant) {
+        FrameBlock exp = new FrameBlock(schema);
+        String[] out = constant ?
+            new String[]{"1", "1", "1", "1"} :
+            new String[]{"1", "abc", "2.5", "TRUE"};
+        for(int i=0; i<rows; i++)
+            exp.appendRow(out);
+        return exp;
+    }
+}
diff --git a/src/test/scripts/functions/frame/FrameConstructorTest.dml b/src/test/scripts/functions/frame/FrameConstructorTest.dml
new file mode 100644
index 0000000..8762d30
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameConstructorTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Builds frame f1 for the test case named in $1 and writes it as csv to $2.
+print("param 1 "+$1)
+if($1 == "NAMED")
+ f1 = frame(data=["1", "abc", "2.5", "TRUE"], rows=40, cols=4, schema=["INT64", "STRING", "FP64", "BOOLEAN"]) # all named
+if($1 == "NO_SCHEMA")
+ f1 = frame(data=["1", "abc", "2.5", "TRUE"], rows=40, cols=4) # no schema
+if($1 == "RANDOM_DATA")
+ f1 = frame("", rows=40, cols=4, schema=["INT64", "STRING", "FP64", "BOOLEAN"]) # empty data string: random values
+if($1 == "SINGLE_DATA")
+ f1 = frame(1, rows=40, cols=4, schema=["INT64", "STRING", "FP64", "BOOLEAN"]) # single value for all cells
+
+# f1 = frame(1, 4, 3) # unnamed parameters not working
+write(f1, $2, format="csv")