[SYSTEMDS-2617] New builtin for obtaining frame column names
New builtin colnames(X) for obtaining a single-row frame holding the
column names by position.
Closes #1020.
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 22134ea..cc5b12b 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -64,6 +64,7 @@
COLMAX("colMaxs", false),
COLMEAN("colMeans", false),
COLMIN("colMins", false),
+ COLNAMES("colnames", false),
COLPROD("colProds", false),
COLSD("colSds", false),
COLSUM("colSums", false),
@@ -182,7 +183,7 @@
TANH("tanh", false),
TRACE("trace", false),
TO_ONE_HOT("toOneHot", true),
- TYPEOF("typeOf", false),
+ TYPEOF("typeof", false),
COUNT_DISTINCT("countDistinct",false),
COUNT_DISTINCT_APPROX("countDistinctApprox",false),
VAR("var", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 92027a5..978c644 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -195,7 +195,7 @@
ABS, ACOS, ASIN, ASSERT, ATAN, CAST_AS_SCALAR, CAST_AS_MATRIX,
CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
- CUMSUMPROD, DETECTSCHEMA, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
+ CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
MEDIAN, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, SVD,
TAN, TANH, TYPEOF,
@@ -231,6 +231,7 @@
case CUMPROD: return "ucum*";
case CUMSUM: return "ucumk+";
case CUMSUMPROD: return "ucumk+*";
+ case COLNAMES: return "colnames";
case DETECTSCHEMA: return "detectSchema";
case MULT2: return "*2";
case NOT: return "!";
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index f9b46a8..6da0e32 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -539,7 +539,8 @@
setDim1(input.getDim1());
setDim2(1);
}
- else if(_op == OpOp1.TYPEOF || _op == OpOp1.DETECTSCHEMA) {
+ else if(_op == OpOp1.TYPEOF || _op == OpOp1.DETECTSCHEMA || _op == OpOp1.COLNAMES) {
+ //TODO theses three builtins should rather be moved to unary aggregates
setDim1(1);
setDim2(input.getDim2());
}
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index d3966e2..7db411c 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -716,6 +716,7 @@
break;
case TYPEOF:
case DETECTSCHEMA:
+ case COLNAMES:
checkNumParameters(1);
checkMatrixFrameParam(getFirstExpr());
output.setDataType(DataType.FRAME);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index f84f469..4747bfe 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2656,8 +2656,9 @@
case CHOLESKY:
case TYPEOF:
case DETECTSCHEMA:
- currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
- OpOp1.valueOf(source.getOpCode().name()), expr);
+ case COLNAMES:
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), OpOp1.valueOf(source.getOpCode().name()), expr);
break;
case OUTER:
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 30ec6bd..d280b16 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -190,6 +190,7 @@
String2CPInstructionType.put( "sigmoid", CPType.Unary);
String2CPInstructionType.put( "typeOf", CPType.Unary);
String2CPInstructionType.put( "detectSchema", CPType.Unary);
+ String2CPInstructionType.put( "colnames", CPType.Unary);
String2CPInstructionType.put( "isna", CPType.Unary);
String2CPInstructionType.put( "isnan", CPType.Unary);
String2CPInstructionType.put( "isinf", CPType.Unary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index a53acb9..b4104e1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -93,10 +93,10 @@
String2SPInstructionType = new HashMap<>();
//unary aggregate operators
- String2SPInstructionType.put( "uak+" , SPType.AggregateUnary);
+ String2SPInstructionType.put( "uak+" , SPType.AggregateUnary);
String2SPInstructionType.put( "uark+" , SPType.AggregateUnary);
String2SPInstructionType.put( "uack+" , SPType.AggregateUnary);
- String2SPInstructionType.put( "uasqk+" , SPType.AggregateUnary);
+ String2SPInstructionType.put( "uasqk+" , SPType.AggregateUnary);
String2SPInstructionType.put( "uarsqk+" , SPType.AggregateUnary);
String2SPInstructionType.put( "uacsqk+" , SPType.AggregateUnary);
String2SPInstructionType.put( "uamean" , SPType.AggregateUnary);
@@ -107,7 +107,7 @@
String2SPInstructionType.put( "uacvar" , SPType.AggregateUnary);
String2SPInstructionType.put( "uamax" , SPType.AggregateUnary);
String2SPInstructionType.put( "uarmax" , SPType.AggregateUnary);
- String2SPInstructionType.put( "uarimax", SPType.AggregateUnary);
+ String2SPInstructionType.put( "uarimax" , SPType.AggregateUnary);
String2SPInstructionType.put( "uacmax" , SPType.AggregateUnary);
String2SPInstructionType.put( "uamin" , SPType.AggregateUnary);
String2SPInstructionType.put( "uarmin" , SPType.AggregateUnary);
@@ -127,7 +127,7 @@
String2SPInstructionType.put( "mapmmchain" , SPType.MAPMMCHAIN);
String2SPInstructionType.put( "tsmm" , SPType.TSMM); //single-pass tsmm
String2SPInstructionType.put( "tsmm2" , SPType.TSMM2); //multi-pass tsmm
- String2SPInstructionType.put( "cpmm" , SPType.CPMM);
+ String2SPInstructionType.put( "cpmm" , SPType.CPMM);
String2SPInstructionType.put( "rmm" , SPType.RMM);
String2SPInstructionType.put( "pmm" , SPType.PMM);
String2SPInstructionType.put( "zipmm" , SPType.ZIPMM);
@@ -141,42 +141,42 @@
String2SPInstructionType.put( "tack+*" , SPType.AggregateTernary);
// Neural network operators
- String2SPInstructionType.put( "conv2d", SPType.Dnn);
+ String2SPInstructionType.put( "conv2d", SPType.Dnn);
String2SPInstructionType.put( "conv2d_bias_add", SPType.Dnn);
- String2SPInstructionType.put( "maxpooling", SPType.Dnn);
- String2SPInstructionType.put( "relu_maxpooling", SPType.Dnn);
+ String2SPInstructionType.put( "maxpooling", SPType.Dnn);
+ String2SPInstructionType.put( "relu_maxpooling", SPType.Dnn);
String2SPInstructionType.put( RightIndex.OPCODE, SPType.MatrixIndexing);
- String2SPInstructionType.put( LeftIndex.OPCODE, SPType.MatrixIndexing);
- String2SPInstructionType.put( "mapLeftIndex" , SPType.MatrixIndexing);
+ String2SPInstructionType.put( LeftIndex.OPCODE, SPType.MatrixIndexing);
+ String2SPInstructionType.put( "mapLeftIndex", SPType.MatrixIndexing);
// Reorg Instruction Opcodes (repositioning of existing values)
- String2SPInstructionType.put( "r'" , SPType.Reorg);
- String2SPInstructionType.put( "rev" , SPType.Reorg);
- String2SPInstructionType.put( "rdiag" , SPType.Reorg);
- String2SPInstructionType.put( "rshape" , SPType.MatrixReshape);
- String2SPInstructionType.put( "rsort" , SPType.Reorg);
+ String2SPInstructionType.put( "r'", SPType.Reorg);
+ String2SPInstructionType.put( "rev", SPType.Reorg);
+ String2SPInstructionType.put( "rdiag", SPType.Reorg);
+ String2SPInstructionType.put( "rshape", SPType.MatrixReshape);
+ String2SPInstructionType.put( "rsort", SPType.Reorg);
- String2SPInstructionType.put( "+" , SPType.Binary);
- String2SPInstructionType.put( "-" , SPType.Binary);
- String2SPInstructionType.put( "*" , SPType.Binary);
- String2SPInstructionType.put( "/" , SPType.Binary);
- String2SPInstructionType.put( "%%" , SPType.Binary);
- String2SPInstructionType.put( "%/%" , SPType.Binary);
- String2SPInstructionType.put( "1-*" , SPType.Binary);
- String2SPInstructionType.put( "^" , SPType.Binary);
- String2SPInstructionType.put( "^2" , SPType.Binary);
- String2SPInstructionType.put( "*2" , SPType.Binary);
- String2SPInstructionType.put( "map+" , SPType.Binary);
- String2SPInstructionType.put( "map-" , SPType.Binary);
- String2SPInstructionType.put( "map*" , SPType.Binary);
- String2SPInstructionType.put( "map/" , SPType.Binary);
- String2SPInstructionType.put( "map%%" , SPType.Binary);
- String2SPInstructionType.put( "map%/%" , SPType.Binary);
- String2SPInstructionType.put( "map1-*" , SPType.Binary);
- String2SPInstructionType.put( "map^" , SPType.Binary);
- String2SPInstructionType.put( "map+*" , SPType.Binary);
- String2SPInstructionType.put( "map-*" , SPType.Binary);
+ String2SPInstructionType.put( "+", SPType.Binary);
+ String2SPInstructionType.put( "-", SPType.Binary);
+ String2SPInstructionType.put( "*", SPType.Binary);
+ String2SPInstructionType.put( "/", SPType.Binary);
+ String2SPInstructionType.put( "%%", SPType.Binary);
+ String2SPInstructionType.put( "%/%", SPType.Binary);
+ String2SPInstructionType.put( "1-*", SPType.Binary);
+ String2SPInstructionType.put( "^", SPType.Binary);
+ String2SPInstructionType.put( "^2", SPType.Binary);
+ String2SPInstructionType.put( "*2", SPType.Binary);
+ String2SPInstructionType.put( "map+", SPType.Binary);
+ String2SPInstructionType.put( "map-", SPType.Binary);
+ String2SPInstructionType.put( "map*", SPType.Binary);
+ String2SPInstructionType.put( "map/", SPType.Binary);
+ String2SPInstructionType.put( "map%%", SPType.Binary);
+ String2SPInstructionType.put( "map%/%", SPType.Binary);
+ String2SPInstructionType.put( "map1-*", SPType.Binary);
+ String2SPInstructionType.put( "map^", SPType.Binary);
+ String2SPInstructionType.put( "map+*", SPType.Binary);
+ String2SPInstructionType.put( "map-*", SPType.Binary);
String2SPInstructionType.put( "dropInvalidType", SPType.Binary);
String2SPInstructionType.put( "mapdropInvalidLength", SPType.Binary);
// Relational Instruction Opcodes
@@ -250,6 +250,7 @@
String2SPInstructionType.put( "sprop", SPType.Unary);
String2SPInstructionType.put( "sigmoid", SPType.Unary);
String2SPInstructionType.put( "detectSchema", SPType.Unary);
+ String2SPInstructionType.put( "colnames", SPType.Unary);
String2SPInstructionType.put( "isna", SPType.Unary);
String2SPInstructionType.put( "isnan", SPType.Unary);
String2SPInstructionType.put( "isinf", SPType.Unary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
index 13af891..4cbf93c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -37,12 +38,19 @@
ec.releaseFrameInput(input1.getName());
ec.setFrameOutput(output.getName(), retBlock);
}
- else if(getOpcode().equals("detectSchema"))
- {
+ else if(getOpcode().equals("detectSchema")) {
FrameBlock inBlock = ec.getFrameInput(input1.getName());
FrameBlock retBlock = inBlock.detectSchemaFromRow(Lop.SAMPLE_FRACTION);
ec.releaseFrameInput(input1.getName());
ec.setFrameOutput(output.getName(), retBlock);
}
+ else if(getOpcode().equals("colnames")) {
+ FrameBlock inBlock = ec.getFrameInput(input1.getName());
+ FrameBlock retBlock = inBlock.getColumnNamesAsFrame();
+ ec.releaseFrameInput(input1.getName());
+ ec.setFrameOutput(output.getName(), retBlock);
+ }
+ else
+ throw new DMLScriptException("Opcode '" + getOpcode() + "' is not a valid UnaryFrameCPInstruction");
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java
index 6cd2785..d4bcf42 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java
@@ -23,7 +23,9 @@
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -37,7 +39,7 @@
super(SPInstruction.SPType.Unary, op, in, out, opcode, instr);
}
- public static UnaryFrameSPInstruction parseInstruction (String str ) {
+ public static UnaryFrameSPInstruction parseInstruction(String str) {
CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
String opcode = parseUnaryInstruction(str, in, out);
@@ -46,18 +48,36 @@
@Override
public void processInstruction(ExecutionContext ec) {
- SparkExecutionContext sec = (SparkExecutionContext)ec;
- //get input
- JavaPairRDD<Long, FrameBlock> in = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName() );
- JavaPairRDD<Long,FrameBlock> out = in.mapToPair(new DetectSchemaUsingRows());
+ SparkExecutionContext sec = (SparkExecutionContext) ec;
+ if(getOpcode().equals(OpOp1.DETECTSCHEMA.toString()))
+ detectSchema(sec);
+ else if(getOpcode().equals(OpOp1.COLNAMES.toString()))
+ columnNames(sec);
+ else
+ throw new DMLScriptException("Opcode '" + getOpcode() + "' is not a valid UnaryFrameSPInstruction");
+ }
+
+ private void columnNames(SparkExecutionContext sec) {
+ // get input
+ JavaPairRDD<Long, FrameBlock> in = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
+ // get the first row block (frames are only blocked rowwise) and get its column names
+ FrameBlock outFrame = in.lookup(1L).get(0).getColumnNamesAsFrame();
+ sec.setFrameOutput(output.getName(), outFrame);
+ }
+
+ public void detectSchema(SparkExecutionContext sec) {
+ // get input
+ JavaPairRDD<Long, FrameBlock> in = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
+ JavaPairRDD<Long, FrameBlock> out = in.mapToPair(new DetectSchemaUsingRows());
FrameBlock outFrame = out.values().reduce(new MergeFrame());
sec.setFrameOutput(output.getName(), outFrame);
}
private static class DetectSchemaUsingRows implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
private static final long serialVersionUID = 5850400295183766400L;
+
@Override
- public Tuple2<Long,FrameBlock> call(Tuple2<Long, FrameBlock> arg0) throws Exception {
+ public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> arg0) throws Exception {
FrameBlock resultBlock = new FrameBlock(arg0._2.detectSchemaFromRow(Lop.SAMPLE_FRACTION));
return new Tuple2<>(1L, resultBlock);
}
@@ -65,6 +85,7 @@
private static class MergeFrame implements Function2<FrameBlock, FrameBlock, FrameBlock> {
private static final long serialVersionUID = 942744896521069893L;
+
@Override
public FrameBlock call(FrameBlock arg0, FrameBlock arg1) throws Exception {
return new FrameBlock(FrameBlock.mergeSchema(arg0, arg1));
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index e473acd..7ae6b53 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -176,7 +176,14 @@
public String[] getColumnNames() {
return getColumnNames(true);
}
-
+
+
+ public FrameBlock getColumnNamesAsFrame() {
+ FrameBlock fb = new FrameBlock(getNumColumns(), ValueType.STRING);
+ fb.appendRow(getColumnNames());
+ return fb;
+ }
+
/**
* Returns the column names of the frame block. This method
* allocates default column names if required.
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4e55248..7e63127 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -20,8 +20,6 @@
package org.apache.sysds.test;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.ByteArrayOutputStream;
@@ -197,13 +195,6 @@
private boolean isOutAndExpectedDeletionDisabled = false;
- private int iExpectedStdOutState = 0;
- private int iUnexpectedStdOutState = 0;
- // private PrintStream originalPrintStreamStd = null;
-
- private int iExpectedStdErrState = 0;
- // private PrintStream originalErrStreamStd = null;
-
private boolean outputBuffering = true;
// Timestamp before test start.
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java
new file mode 100644
index 0000000..be00f61
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import java.util.Arrays;
+import java.util.Collection;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import edu.emory.mathcs.backport.java.util.Collections;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FrameColumnNamesTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "ColumnNames";
+ private final static String TEST_DIR = "functions/frame/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FrameColumnNamesTest.class.getSimpleName() + "/";
+
+ private final static int _rows = 10000;
+ @Parameterized.Parameter()
+ public String[] _columnNames;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {{new String[] {"A", "B", "C"}}, {new String[] {"1", "2", "3"}},
+ {new String[] {"Hello", "hello", "Hello", "hi", "u", "w", "u"}},});
+ }
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+ }
+
+ @Test
+ public void testDetectSchemaDoubleCP() {
+ runGetColNamesTest(_columnNames, ExecType.CP);
+ }
+
+ @Test
+ public void testDetectSchemaDoubleSpark() {
+ runGetColNamesTest(_columnNames, ExecType.SPARK);
+ }
+
+ @SuppressWarnings("unchecked")
+ private void runGetColNamesTest(String[] columnNames, ExecType et) {
+ Types.ExecMode platformOld = setExecMode(et);
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ try {
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-args", input("A"), String.valueOf(_rows),
+ Integer.toString(columnNames.length), output("B")};
+
+ Types.ValueType[] schema = (Types.ValueType[]) Collections
+ .nCopies(columnNames.length, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+ FrameBlock frame1 = new FrameBlock(schema);
+ frame1.setColumnNames(columnNames);
+ FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+ new FileFormatPropertiesCSV(true, ",", false));
+
+ double[][] A = getRandomMatrix(_rows, schema.length, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+ TestUtils.initFrameData(frame1, A, schema, _rows);
+ writer.writeFrameToHDFS(frame1, input("A"), _rows, schema.length);
+
+ runTest(true, false, null, -1);
+ FrameBlock frame2 = readDMLFrameFromHDFS("B", FileFormat.BINARY);
+
+ // verify output schema
+ for(int i = 0; i < schema.length; i++) {
+ Assert
+ .assertEquals("Wrong result: " + columnNames[i] + ".", columnNames[i], frame2.get(0, i).toString());
+ }
+ }
+ catch(Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ finally {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
diff --git a/src/test/scripts/functions/frame/ColumnNames.dml b/src/test/scripts/functions/frame/ColumnNames.dml
new file mode 100644
index 0000000..319a03c
--- /dev/null
+++ b/src/test/scripts/functions/frame/ColumnNames.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE);
+R = colnames(X);
+write(R, $4, format="binary");
\ No newline at end of file
diff --git a/src/test/scripts/functions/frame/TypeOf.dml b/src/test/scripts/functions/frame/TypeOf.dml
index 7394541..6e8b3bb 100644
--- a/src/test/scripts/functions/frame/TypeOf.dml
+++ b/src/test/scripts/functions/frame/TypeOf.dml
@@ -20,7 +20,7 @@
#-------------------------------------------------------------
X = read($1, rows=$2, cols=$3, data_type="frame", format="csv");
-R = typeOf(X);
+R = typeof(X);
print(toString(R))
write(R, $4, format="binary");
\ No newline at end of file