[SYSTEMDS-2633] Frame map function for evaluating lambda expressions
The built-in function uses the Janino compiler for run-time code
generation and compilation and accepts a frame and string lambda
expression (containing Java code) as input and execute the code in
string on frame and returns the output frame.
Closes #1034.
diff --git a/.github/workflows/functionsTests.yml b/.github/workflows/functionsTests.yml
index 7d0349a..d3930b1 100644
--- a/.github/workflows/functionsTests.yml
+++ b/.github/workflows/functionsTests.yml
@@ -38,6 +38,7 @@
tests: [
aggregate,
append,
+ binary.frame,
binary.matrix,
binary.matrix_full_cellwise,
binary.matrix_full_other,
diff --git a/docs/site/dml-language-reference.md b/docs/site/dml-language-reference.md
index 5d7a96b..f5d9700 100644
--- a/docs/site/dml-language-reference.md
+++ b/docs/site/dml-language-reference.md
@@ -60,6 +60,7 @@
* [Indexing Frames](#indexing-frames)
* [Casting Frames](#casting-frames)
* [Transforming Frames](#transforming-frames)
+ * [Processing Frames](#processing-frames)
* [Modules](#modules)
* [Reserved Keywords](#reserved-keywords)
@@ -2023,6 +2024,47 @@
2.000 2.000 1.000 2.000 2.500 2.000 2.000 1.000 889.000
4.000 1.000 1.000 3.000 1.500 1.000 1.000 1.000 628.000
+### Processing Frames
+
+Built-In functions <code>dml_map()</code> is supported for frames to execute any arbitrary Java code on a frame.
+
+**Table F5**: Frame dml_map Built-In Function
+
+Function | Description | Parameters | Example
+-------- | ----------- | ---------- | -------
+dml_map() | It will execute the given java code on a frame (column-vector).| Input: (X <frame>, y <String>) <br/>Output: <frame>. <br/> X is a frame and y is a String containing the Java code to be executed on frame X. where X is a column vector. | X = read("file1", data_type="frame", rows=2, cols=3, format="binary") <br/> y = "Java code" <br/> Z = dml_map(X, y) <br/> # Dimensions of Z = Dimensions of X; <br/> example: Z = dml_map(X, "x.charAt(2)")
+Example let X =
+
+ ##### FRAME: nrow = 10, ncol = 1 <br/>
+ # C1
+ # STRING
+ west
+ south
+ north
+ east
+ south
+ north
+ north
+ west
+ west
+ east
+
+Z = dml_map(X, "x.toUpperCase()") <br/>
+print(toString(Z))
+ ##### FRAME: nrow = 10, ncol = 1 <br/>
+ # C1
+ # STRING
+ WEST
+ SOUTH
+ NORTH
+ EAST
+ SOUTH
+ NORTH
+ NORTH
+ WEST
+ WEST
+ EAST
+
* * *
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 9e14f50..20fad72 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -93,7 +93,7 @@
DROP_INVALID_LENGTH("dropInvalidLength", false),
EIGEN("eigen", false, ReturnType.MULTI_RETURN),
EXISTS("exists", false),
- ExecutePipeline("executePipeline", true),
+ EXECUTE_PIPELINE("executePipeline", true),
EXP("exp", false),
EVAL("eval", false),
FLOOR("floor", false),
@@ -127,6 +127,7 @@
LSTM("lstm", false, ReturnType.MULTI_RETURN),
LSTM_BACKWARD("lstm_backward", false, ReturnType.MULTI_RETURN),
LU("lu", false, ReturnType.MULTI_RETURN),
+ MAP("map", false),
MEAN("mean", "avg", false),
MICE("mice", true),
MIN("min", "pmin", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 978c644..651214e 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -267,12 +267,12 @@
public enum OpOp2 {
AND(true), BITWAND(true), BITWOR(true), BITWSHIFTL(true), BITWSHIFTR(true),
BITWXOR(true), CBIND(false), CONCAT(false), COV(false), DIV(true),
- DROP_INVALID_TYPE(false), DROP_INVALID_LENGTH(false),
- EQUAL(true), GREATER(true), GREATEREQUAL(true),
- INTDIV(true), INTERQUANTILE(false), IQM(false), LESS(true), LESSEQUAL(true),
- LOG(true), MAX(true), MEDIAN(false), MIN(true), MINUS(true), MODULUS(true),
- MOMENT(false), MULT(true), NOTEQUAL(true), OR(true), PLUS(true), POW(true),
- PRINT(false), QUANTILE(false), SOLVE(false), RBIND(false), XOR(true),
+ DROP_INVALID_TYPE(false), DROP_INVALID_LENGTH(false), EQUAL(true), GREATER(true),
+ GREATEREQUAL(true), INTDIV(true), INTERQUANTILE(false), IQM(false), LESS(true),
+ LESSEQUAL(true), LOG(true), MAP(false), MAX(true), MEDIAN(false), MIN(true),
+ MINUS(true), MODULUS(true), MOMENT(false), MULT(true), NOTEQUAL(true), OR(true),
+ PLUS(true), POW(true), PRINT(false), QUANTILE(false), SOLVE(false), RBIND(false),
+ XOR(true),
//fused ML-specific operators for performance
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
@@ -317,6 +317,7 @@
case BITWSHIFTR: return "bitwShiftR";
case DROP_INVALID_TYPE: return "dropInvalidType";
case DROP_INVALID_LENGTH: return "dropInvalidLength";
+ case MAP: return "dml_map";
default: return name().toLowerCase();
}
}
@@ -350,6 +351,7 @@
case "bitwShiftR": return BITWSHIFTR;
case "dropInvalidType": return DROP_INVALID_TYPE;
case "dropInvalidLength": return DROP_INVALID_LENGTH;
+ case "map": return MAP;
default: return valueOf(opcode.toUpperCase());
}
}
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 586d675..cc5d58d 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -1054,7 +1054,10 @@
{
if( !(that instanceof BinaryOp) )
return false;
-
+
+ if(op == OpOp2.MAP)
+ return false; // custom UDFs
+
BinaryOp that2 = (BinaryOp)that;
return ( op == that2.op
&& outer == that2.outer
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 7db411c..088e740 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1549,6 +1549,16 @@
output.setValueType(id.getValueType());
break;
+ case MAP:
+ checkNumParameters(2);
+ checkMatrixFrameParam(getFirstExpr());
+ checkScalarParam(getSecondExpr());
+ output.setDataType(DataType.FRAME);
+ output.setDimensions(id.getDim1(), 1);
+ output.setBlocksize (id.getBlocksize());
+ output.setValueType(ValueType.STRING);
+ break;
+
default:
if( isMathFunction() ) {
checkMathFunctionParam();
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 4747bfe..c2e0f95 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2536,6 +2536,7 @@
break;
case DROP_INVALID_TYPE:
case DROP_INVALID_LENGTH:
+ case MAP:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
break;
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 5d66c15..3022c5d 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -50,7 +50,7 @@
public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
- TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, DROP_INVALID_LENGTH,
+ TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, DROP_INVALID_LENGTH, DML_MAP,
COUNT_DISTINCT, COUNT_DISTINCT_APPROX}
@@ -107,6 +107,7 @@
String2BuiltinCode.put( "isinf", BuiltinCode.ISINF);
String2BuiltinCode.put( "dropInvalidType", BuiltinCode.DROP_INVALID_TYPE);
String2BuiltinCode.put( "dropInvalidLength", BuiltinCode.DROP_INVALID_LENGTH);
+ String2BuiltinCode.put( "dml_map", BuiltinCode.DML_MAP);
}
private Builtin(BuiltinCode bf) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index d280b16..f8a2636 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -154,6 +154,7 @@
String2CPInstructionType.put( "min" , CPType.Binary);
String2CPInstructionType.put( "dropInvalidType" , CPType.Binary);
String2CPInstructionType.put( "dropInvalidLength" , CPType.Binary);
+ String2CPInstructionType.put( "dml_map" , CPType.Binary);
String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index b4104e1..42bf6f9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -179,6 +179,7 @@
String2SPInstructionType.put( "map-*", SPType.Binary);
String2SPInstructionType.put( "dropInvalidType", SPType.Binary);
String2SPInstructionType.put( "mapdropInvalidLength", SPType.Binary);
+ String2SPInstructionType.put( "dml_map", SPType.Binary);
// Relational Instruction Opcodes
String2SPInstructionType.put( "==" , SPType.Binary);
String2SPInstructionType.put( "!=" , SPType.Binary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index bd87a15..2f0aad4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -57,6 +57,8 @@
return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str);
else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX)
return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str);
+ else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR)
+ return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str);
else
return new BinaryMatrixScalarCPInstruction(operator, in1, in2, out, opcode, str);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
new file mode 100644
index 0000000..bcf7cb5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction
+{
+ protected BinaryFrameScalarCPInstruction(Operator op, CPOperand in1,
+ CPOperand in2, CPOperand out, String opcode, String istr) {
+ super(CPType.Binary, op, in1, in2, out, opcode, istr);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ // get input frames
+ FrameBlock inBlock = ec.getFrameInput(input1.getName());
+ String stringExpression = ec.getScalarInput(input2).getStringValue();
+ //compute results
+ FrameBlock outBlock = inBlock.map(stringExpression);
+ // Attach result frame with FrameBlock associated with output_name
+ ec.setFrameOutput(output.getName(), outBlock);
+ // Release the memory occupied by input frames
+ ec.releaseFrameInput(input1.getName());
+ }
+}
+
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
index 6966178..deb8fb4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
@@ -22,11 +22,8 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
-import org.apache.sysds.common.Types;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -34,25 +31,11 @@
import scala.Tuple2;
public class BinaryFrameFrameSPInstruction extends BinarySPInstruction {
- protected BinaryFrameFrameSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+ protected BinaryFrameFrameSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr) {
super(SPType.Binary, op, in1, in2, out, opcode, istr);
}
- public static BinarySPInstruction parseInstruction ( String str) {
- String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields (parts, 3);
- String opcode = parts[0];
- CPOperand in1 = new CPOperand(parts[1]);
- CPOperand in2 = new CPOperand(parts[2]);
- CPOperand out = new CPOperand(parts[3]);
- Types.DataType dt1 = in1.getDataType();
- Types.DataType dt2 = in2.getDataType();
- Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
- if(dt1 == Types.DataType.FRAME && dt2 == Types.DataType.FRAME)
- return new BinaryFrameFrameSPInstruction(operator, in1, in2, out, opcode, str);
- else
- throw new DMLRuntimeException("Frame binary operation not yet implemented for frame-scalar, or frame-matrix");
- }
@Override
public void processInstruction(ExecutionContext ec) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
new file mode 100644
index 0000000..a395c16
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.spark;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BinaryFrameScalarSPInstruction extends BinarySPInstruction {
+ protected BinaryFrameScalarSPInstruction (Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr) {
+ super(SPType.Binary, op, in1, in2, out, opcode, istr);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ SparkExecutionContext sec = (SparkExecutionContext)ec;
+
+ // Get input RDDs
+ JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
+ String expression = sec.getScalarInput(input2).getStringValue();
+
+ // Create local compiled functions (once) and execute on RDD
+ JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new RDDStringProcessing(expression));
+
+ sec.setRDDHandleForVariable(output.getName(), out);
+ sec.addLineageRDD(output.getName(), input1.getName());
+ }
+
+ private static class RDDStringProcessing implements Function<FrameBlock,FrameBlock> {
+ private static final long serialVersionUID = 5850400295183766400L;
+
+ private String _expr = null;
+
+ public RDDStringProcessing(String expr) {
+ _expr = expr;
+ }
+
+ @Override
+ public FrameBlock call(FrameBlock arg0) throws Exception {
+ return arg0.map(_expr);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index ee96dc9..f4f98dc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -109,7 +109,9 @@
}
else if( dt1 == DataType.FRAME || dt2 == DataType.FRAME ) {
if(dt1 == DataType.FRAME && dt2 == DataType.FRAME)
- return BinaryFrameFrameSPInstruction.parseInstruction(str);
+ return new BinaryFrameFrameSPInstruction(operator, in1, in2, out, opcode, str);
+ if(dt1 == DataType.FRAME && dt2 == DataType.SCALAR)
+ return new BinaryFrameScalarSPInstruction(operator, in1, in2, out, opcode, str);
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 8a094d0..d5ee5a7 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -47,7 +47,9 @@
import org.apache.sysds.api.DMLException;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.ValueComparisonFunction;
import org.apache.sysds.runtime.instructions.cp.*;
import org.apache.sysds.runtime.io.IOUtilFunctions;
@@ -59,9 +61,9 @@
@SuppressWarnings({"rawtypes","unchecked"}) //allow generic native arrays
public class FrameBlock implements CacheBlock, Externalizable {
- private static final Log LOG = LogFactory.getLog(FrameBlock.class.getName());
-
private static final long serialVersionUID = -3993450030207130665L;
+ private static final Log LOG = LogFactory.getLog(FrameBlock.class.getName());
+ private static final IDSequence CLASS_ID = new IDSequence();
public static final int BUFFER_SIZE = 1 * 1000 * 1000; //1M elements, size of default matrix block
@@ -2078,4 +2080,56 @@
mergedFrame.appendRow(rowTemp1);
return mergedFrame;
}
+
+ public FrameBlock map(String lambdaExpr) {
+ return map(getCompiledFunction(lambdaExpr));
+ }
+
+ public FrameBlock map(FrameMapFunction lambdaExpr) {
+ // Prepare temporary output array
+ String[][] output = new String[getNumRows()][getNumColumns()];
+
+ // Execute map function on all cells
+ for(int j=0; j<getNumColumns(); j++) {
+ Array input = getColumn(j);
+ for (int i = 0; i < input._size; i++)
+ if(input.get(i) != null)
+ output[i][j] = lambdaExpr.apply(String.valueOf(input.get(i)));
+ }
+
+ return new FrameBlock(UtilFunctions.nCopies(getNumColumns(), ValueType.STRING), output);
+ }
+
+ public static FrameMapFunction getCompiledFunction(String lambdaExpr) {
+ // split lambda expression
+ String[] parts = lambdaExpr.split("->");
+ if( parts.length != 2 )
+ throw new DMLRuntimeException("Unsupported lambda expression: "+lambdaExpr);
+ String varname = parts[0].trim();
+ String expr = parts[1].trim();
+
+ // construct class code
+ String cname = "StringProcessing"+CLASS_ID.getNextID();
+ StringBuilder sb = new StringBuilder();
+ sb.append("import org.apache.sysds.runtime.util.UtilFunctions;\n");
+ sb.append("import org.apache.sysds.runtime.matrix.data.FrameBlock.FrameMapFunction;\n");
+ sb.append("public class "+cname+" extends FrameMapFunction {\n");
+ sb.append("@Override\n");
+ sb.append("public String apply(String "+varname+") {\n");
+ sb.append(" return String.valueOf("+expr+"); }}\n");
+
+ // compile class, and create FrameMapFunction object
+ try {
+ return (FrameMapFunction) CodegenUtils
+ .compileClass(cname, sb.toString()).newInstance();
+ }
+ catch(InstantiationException | IllegalAccessException e) {
+ throw new DMLRuntimeException("Failed to compile FrameMapFunction.", e);
+ }
+ }
+
+ public static abstract class FrameMapFunction implements Serializable {
+ private static final long serialVersionUID = -8398572153616520873L;
+ public abstract String apply(String input);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index 77149d1..bb98f5c 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -19,12 +19,9 @@
package org.apache.sysds.runtime.util;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.BitSet;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
+import java.text.ParseException;
+import java.text.SimpleDateFormat;
+import java.util.*;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math3.random.RandomDataGenerator;
@@ -794,4 +791,49 @@
break;
}
}
+
+ private static final Map<String, String> DATE_FORMATS = new HashMap<String, String>() {
+ private static final long serialVersionUID = 6826162458614520846L; {
+ put("^\\d{4}-\\d{1,2}-\\d{1,2}\\s\\d{1,2}:\\d{2}:\\d{2}$", "yyyy-MM-dd HH:mm:ss");
+ put("^\\d{1,2}-\\d{1,2}-\\d{4}\\s\\d{1,2}:\\d{2}:\\d{2}$", "dd-MM-yyyy HH:mm:ss");
+ put("^\\d{1,2}/\\d{1,2}/\\d{4}\\s\\d{1,2}:\\d{2}:\\d{2}$", "MM/dd/yyyy HH:mm:ss");
+ put("^\\d{4}/\\d{1,2}/\\d{1,2}\\s\\d{1,2}:\\d{2}:\\d{2}$", "yyyy/MM/dd HH:mm:ss");
+ put("^\\d{1,2}\\s[a-z]{3}\\s\\d{4}\\s\\d{1,2}:\\d{2}:\\d{2}$", "dd MMM yyyy HH:mm:ss");
+ put("^\\d{1,2}\\s[a-z]{4,}\\s\\d{4}\\s\\d{1,2}:\\d{2}:\\d{2}$", "dd MMMM yyyy HH:mm:ss");
+ put("^\\d{8}$", "yyyyMMdd");
+ put("^\\d{1,2}-\\d{1,2}-\\d{4}$", "dd-MM-yyyy");
+ put("^\\d{4}-\\d{1,2}-\\d{1,2}$", "yyyy-MM-dd");
+ put("^\\d{1,2}/\\d{1,2}/\\d{4}$", "MM/dd/yyyy");
+ put("^\\d{4}/\\d{1,2}/\\d{1,2}$", "yyyy/MM/dd");
+ put("^\\d{1,2}\\s[a-z]{3}\\s\\d{4}$", "dd MMM yyyy");
+ put("^\\d{1,2}\\s[a-z]{4,}\\s\\d{4}$", "dd MMMM yyyy");
+ put("^\\d{12}$", "yyyyMMddHHmm");
+ put("^\\d{8}\\s\\d{4}$", "yyyyMMdd HHmm");
+ put("^\\d{1,2}-\\d{1,2}-\\d{4}\\s\\d{1,2}:\\d{2}$", "dd-MM-yyyy HH:mm");
+ put("^\\d{4}-\\d{1,2}-\\d{1,2}\\s\\d{1,2}:\\d{2}$", "yyyy-MM-dd HH:mm");
+ put("^\\d{1,2}/\\d{1,2}/\\d{4}\\s\\d{1,2}:\\d{2}$", "MM/dd/yyyy HH:mm");
+ put("^\\d{4}/\\d{1,2}/\\d{1,2}\\s\\d{1,2}:\\d{2}$", "yyyy/MM/dd HH:mm");
+ put("^\\d{1,2}\\s[a-z]{3}\\s\\d{4}\\s\\d{1,2}:\\d{2}$", "dd MMM yyyy HH:mm");
+ put("^\\d{1,2}\\s[a-z]{4,}\\s\\d{4}\\s\\d{1,2}:\\d{2}$", "dd MMMM yyyy HH:mm");
+ put("^\\d{14}$", "yyyyMMddHHmmss");
+ put("^\\d{8}\\s\\d{6}$", "yyyyMMdd HHmmss");
+ }};
+
+ public static long toMillis (String dateString) {
+ long value = 0;
+ try {
+ value = new SimpleDateFormat(getDateFormat(dateString)).parse(dateString).getTime();
+ }
+ catch(ParseException e) {
+ throw new DMLRuntimeException(e);
+ }
+ return value ;
+ }
+
+ private static String getDateFormat (String dateString) {
+ return DATE_FORMATS.keySet().parallelStream().filter(e -> dateString.toLowerCase().matches(e)).findFirst()
+ .map(DATE_FORMATS::get).orElseThrow(() -> new NullPointerException("Unknown date format."));
+ }
+
+
}
diff --git a/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapTest.java b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapTest.java
new file mode 100644
index 0000000..db5e0ef
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.binary.frame;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class FrameMapTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "dmlMap";
+ private final static String TEST_DIR = "functions/binary/frame/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FrameMapTest.class.getSimpleName() + "/";
+
+ private final static int rows = 100005;
+ private final static int rows2 = 10;
+ private final static Types.ValueType[] schemaStrings1 = {Types.ValueType.STRING};
+
+ static enum TestType {
+ SPLIT,
+ CHAR_AT,
+ REPLACE,
+ UPPER_CASE,
+ DATE_UTILS
+ }
+ @BeforeClass
+ public static void init() {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @AfterClass
+ public static void cleanUp() {
+ if (TEST_CACHE_ENABLED) {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"D"}));
+ if (TEST_CACHE_ENABLED) {
+ setOutAndExpectedDeletionDisabled(true);
+ }
+ }
+
+ @Test
+ public void testUpperCaseOperationCP() {
+ runDmlMapTest("x -> x.toUpperCase()", TestType.UPPER_CASE, ExecType.CP);
+ }
+
+ @Test
+ public void testSplitOperationCP() {
+ runDmlMapTest("x -> x.split(\"r\")[1]", TestType.SPLIT, ExecType.CP);
+ }
+
+ @Test
+ public void testChatAtOperationCP() {
+ runDmlMapTest("y -> y.charAt(0)", TestType.CHAR_AT, ExecType.CP);
+ }
+
+ @Test
+ public void testReplaceOperationCP() {
+ runDmlMapTest("x -> x.replaceAll(\"[a-zA-Z]\", \"\")", TestType.REPLACE, ExecType.CP);
+ }
+
+ @Test
+ public void testDateUtilsOperationCP() {
+ runDmlMapTest("x -> UtilFunctions.toMillis(x)", TestType.DATE_UTILS, ExecType.CP);
+ }
+
+ @Test
+ public void testSplitOperationSP() {
+ runDmlMapTest("x -> x.split(\"r\")[1]", TestType.SPLIT, ExecType.SPARK);
+ }
+
+ @Test
+ public void testChatAtOperationSP() {
+ runDmlMapTest("y -> y.charAt(0)", TestType.CHAR_AT, ExecType.SPARK);
+ }
+
+ @Test
+ public void testDateUtilsOperationSp() {
+ runDmlMapTest("x -> UtilFunctions.toMillis(x)", TestType.DATE_UTILS, ExecType.SPARK);
+ }
+
+ private void runDmlMapTest( String expression, TestType type, ExecType et)
+ {
+ Types.ExecMode platformOld = setExecMode(et);
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] { "-stats","-args", input("A"), expression,
+ output("O"), output("I")};
+
+ if(type == TestType.DATE_UTILS) {
+ String[][] date = new String[rows2][1];
+ for(int i = 0; i<rows2; i++)
+ date[i][0] = (i%30)+"/"+(i%12)+"/200"+(i%20);
+ FrameWriterFactory.createFrameWriter(FileFormat.CSV).
+ writeFrameToHDFS(new FrameBlock(schemaStrings1, date), input("A"), rows2, 1);
+ }
+ else {
+ double[][] A = getRandomMatrix(rows, 1, 0, 1, 1, 2);
+ writeInputFrameWithMTD("A", A, true, schemaStrings1, FileFormat.CSV);
+ }
+
+ setOutputBuffering(false);
+ runTest(true, false, null, -1);
+
+ FrameBlock outputFrame = readDMLFrameFromHDFS("O", FileFormat.CSV);
+ FrameBlock inputFrame = readDMLFrameFromHDFS("I", FileFormat.CSV);
+
+ String[] output = (String[])outputFrame.getColumnData(0);
+ String[] input = (String[])inputFrame.getColumnData(0);
+
+ switch (type) {
+ case SPLIT:
+ for(int i = 0; i<input.length; i++)
+ TestUtils.compareScalars(input[i].split("r")[1], output[i]);
+ break;
+ case CHAR_AT:
+ for(int i = 0; i<input.length; i++)
+ TestUtils.compareScalars(String.valueOf(input[i].charAt(0)), output[i]);
+ break;
+ case REPLACE:
+ for(int i = 0; i<input.length; i++)
+ TestUtils.compareScalars(String.valueOf(input[i].
+ replaceAll("[a-zA-Z]", "")), output[i]);
+ break;
+ case UPPER_CASE:
+ for(int i = 0; i<input.length; i++)
+ TestUtils.compareScalars(String.valueOf(input[i].toUpperCase()), output[i]);
+ break;
+ case DATE_UTILS:
+ for(int i =0; i<input.length; i++)
+ TestUtils.compareScalars(String.valueOf(UtilFunctions.toMillis(input[i])), output[i]);
+ break;
+ }
+ }
+ catch (Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git a/src/test/scripts/functions/binary/frame/dmlMap.dml b/src/test/scripts/functions/binary/frame/dmlMap.dml
new file mode 100644
index 0000000..482c37e
--- /dev/null
+++ b/src/test/scripts/functions/binary/frame/dmlMap.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# input: 1) frame, 2) lamba expression to execute for each non-null cell
+# output: frame of string columns
+
+# Examples:
+# map(X, "x -> x.split(\"r\")[1]")
+# map(X, "x -> x.charAt(2)")
+# map(X, "y -> UtilFunctions.toMillis(y)")
+
+X = read($1, data_type = "frame", format = "csv", header = FALSE)
+# column vector and string operation
+Y = map(X, $2)
+write(Y, $3, format="csv")
+write(X, $4, format="csv")
\ No newline at end of file