| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.test.functions.mlcontext; |
| |
| import static org.apache.sysds.api.mlcontext.ScriptFactory.dml; |
| |
| import java.io.File; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.List; |
| |
| import org.apache.spark.api.java.JavaRDD; |
| import org.apache.spark.rdd.RDD; |
| import org.apache.spark.sql.Dataset; |
| import org.apache.spark.sql.Row; |
| import org.apache.spark.sql.RowFactory; |
| import org.apache.spark.sql.types.DataTypes; |
| import org.apache.spark.sql.types.StructField; |
| import org.apache.spark.sql.types.StructType; |
| import org.junit.Assert; |
| import org.junit.BeforeClass; |
| import org.junit.Test; |
| import org.apache.sysds.api.mlcontext.FrameFormat; |
| import org.apache.sysds.api.mlcontext.FrameMetadata; |
| import org.apache.sysds.api.mlcontext.FrameSchema; |
| import org.apache.sysds.api.mlcontext.MLResults; |
| import org.apache.sysds.api.mlcontext.MatrixFormat; |
| import org.apache.sysds.api.mlcontext.MatrixMetadata; |
| import org.apache.sysds.api.mlcontext.Script; |
| import org.apache.sysds.api.mlcontext.MLContext.ExplainLevel; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDConverterUtils; |
| import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils; |
| import org.apache.sysds.test.functions.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow; |
| |
| import scala.collection.Iterator; |
| |
| public class MLContextFrameTest extends MLContextTestBase { |
| |
/**
 * Script languages that {@code testFrame} can be asked to build its script in.
 */
public enum SCRIPT_TYPE {
	/** Script written in DML. */
	DML
}
| |
/**
 * Ways in which {@code testFrame} hands frames to a script (input) or reads
 * them back from the results (output).
 */
public enum IO_TYPE {
	/** No dedicated representation; as an output type the 2D string array accessor is used. */
	ANY,
	/** Frame files located under the test base directory. */
	FILE,
	/** {@code JavaRDD<String>} holding CSV lines. */
	JAVA_RDD_STR_CSV,
	/** {@code JavaRDD<String>} holding IJV (row column value) lines. */
	JAVA_RDD_STR_IJV,
	/** {@code RDD<String>} holding CSV lines. */
	RDD_STR_CSV,
	/** {@code RDD<String>} holding IJV (row column value) lines. */
	RDD_STR_IJV,
	/** Spark {@code Dataset<Row>}. */
	DATAFRAME
}
| |
| private static String CSV_DELIM = ","; |
| |
| @BeforeClass |
| public static void setUpClass() { |
| MLContextTestBase.setUpClass(); |
| ml.setExplainLevel(ExplainLevel.RECOMPILE_HOPS); |
| } |
| |
| @Test |
| public void testFrameJavaRDD_CSV_DML() { |
| testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY); |
| } |
| |
| @Test |
| public void testFrameJavaRDD_CSV_DML_OutJavaRddCSV() { |
| testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.JAVA_RDD_STR_CSV); |
| } |
| |
| @Test |
| public void testFrameJavaRDD_IJV_DML() { |
| testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY); |
| } |
| |
| @Test |
| public void testFrameRDD_IJV_DML() { |
| testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.RDD_STR_IJV, IO_TYPE.ANY); |
| } |
| |
| @Test |
| public void testFrameJavaRDD_IJV_DML_OutRddCSV() { |
| testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.RDD_STR_CSV); |
| } |
| |
| @Test |
| public void testFrameFile_CSV_DML() { |
| testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY); |
| } |
| |
| @Test |
| public void testFrameFile_IJV_DML() { |
| testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY); |
| } |
| |
| @Test |
| public void testFrameDataFrame_CSV_DML() { |
| testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.ANY); |
| } |
| |
| @Test |
| public void testFrameDataFrameOutDataFrame_CSV_DML() { |
| testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.DATAFRAME); |
| } |
| |
| public void testFrame(FrameFormat format, SCRIPT_TYPE script_type, IO_TYPE inputType, IO_TYPE outputType) { |
| |
| System.out.println("MLContextTest - Frame JavaRDD<String> for format: " + format + " Script: " + script_type); |
| |
| List<String> listA = new ArrayList<>(); |
| List<String> listB = new ArrayList<>(); |
| FrameMetadata fmA = null, fmB = null; |
| Script script = null; |
| ValueType[] schemaA = { ValueType.INT64, ValueType.STRING, ValueType.FP64, ValueType.BOOLEAN }; |
| List<ValueType> lschemaA = Arrays.asList(schemaA); |
| FrameSchema fschemaA = new FrameSchema(lschemaA); |
| ValueType[] schemaB = { ValueType.STRING, ValueType.FP64, ValueType.BOOLEAN }; |
| List<ValueType> lschemaB = Arrays.asList(schemaB); |
| FrameSchema fschemaB = new FrameSchema(lschemaB); |
| |
| if (inputType != IO_TYPE.FILE) { |
| if (format == FrameFormat.CSV) { |
| listA.add("1,Str2,3.0,true"); |
| listA.add("4,Str5,6.0,false"); |
| listA.add("7,Str8,9.0,true"); |
| |
| listB.add("Str12,13.0,true"); |
| listB.add("Str25,26.0,false"); |
| |
| fmA = new FrameMetadata(FrameFormat.CSV, fschemaA, 3, 4); |
| fmB = new FrameMetadata(FrameFormat.CSV, fschemaB, 2, 3); |
| } else if (format == FrameFormat.IJV) { |
| listA.add("1 1 1"); |
| listA.add("1 2 Str2"); |
| listA.add("1 3 3.0"); |
| listA.add("1 4 true"); |
| listA.add("2 1 4"); |
| listA.add("2 2 Str5"); |
| listA.add("2 3 6.0"); |
| listA.add("2 4 false"); |
| listA.add("3 1 7"); |
| listA.add("3 2 Str8"); |
| listA.add("3 3 9.0"); |
| listA.add("3 4 true"); |
| |
| listB.add("1 1 Str12"); |
| listB.add("1 2 13.0"); |
| listB.add("1 3 true"); |
| listB.add("2 1 Str25"); |
| listB.add("2 2 26.0"); |
| listB.add("2 3 false"); |
| |
| fmA = new FrameMetadata(FrameFormat.IJV, fschemaA, 3, 4); |
| fmB = new FrameMetadata(FrameFormat.IJV, fschemaB, 2, 3); |
| } |
| JavaRDD<String> javaRDDA = sc.parallelize(listA); |
| JavaRDD<String> javaRDDB = sc.parallelize(listB); |
| |
| if (inputType == IO_TYPE.DATAFRAME) { |
| JavaRDD<Row> javaRddRowA = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDA, CSV_DELIM, schemaA); |
| JavaRDD<Row> javaRddRowB = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDB, CSV_DELIM, schemaB); |
| |
| // Create DataFrame |
| StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaA, false); |
| Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, dfSchemaA); |
| StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false); |
| Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, dfSchemaB); |
| if (script_type == SCRIPT_TYPE.DML) |
| script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).out("A") |
| .out("C"); |
| } else { |
| if (inputType == IO_TYPE.JAVA_RDD_STR_CSV || inputType == IO_TYPE.JAVA_RDD_STR_IJV) { |
| if (script_type == SCRIPT_TYPE.DML) |
| script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", javaRDDA, fmA).in("B", javaRDDB, fmB).out("A") |
| .out("C"); |
| } else if (inputType == IO_TYPE.RDD_STR_CSV || inputType == IO_TYPE.RDD_STR_IJV) { |
| RDD<String> rddA = JavaRDD.toRDD(javaRDDA); |
| RDD<String> rddB = JavaRDD.toRDD(javaRDDB); |
| |
| if (script_type == SCRIPT_TYPE.DML) |
| script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", rddA, fmA).in("B", rddB, fmB).out("A") |
| .out("C"); |
| } |
| |
| } |
| |
| } else { // Input type is file |
| String fileA = null, fileB = null; |
| if (format == FrameFormat.CSV) { |
| fileA = baseDirectory + File.separator + "FrameA.csv"; |
| fileB = baseDirectory + File.separator + "FrameB.csv"; |
| } else if (format == FrameFormat.IJV) { |
| fileA = baseDirectory + File.separator + "FrameA.ijv"; |
| fileB = baseDirectory + File.separator + "FrameB.ijv"; |
| } |
| |
| if (script_type == SCRIPT_TYPE.DML) |
| script = dml("A=read($A); B=read($B);A[2:3,2:4]=B;C=A[2:3,2:3];A[1,1]=234").in("$A", fileA, fmA) |
| .in("$B", fileB, fmB).out("A").out("C"); |
| } |
| |
| MLResults mlResults = ml.execute(script); |
| |
| //Validate output schema |
| List<ValueType> lschemaOutA = Arrays.asList(mlResults.getFrameObject("A").getSchema()); |
| List<ValueType> lschemaOutC = Arrays.asList(mlResults.getFrameObject("C").getSchema()); |
| Assert.assertEquals(ValueType.INT64, lschemaOutA.get(0)); |
| Assert.assertEquals(ValueType.STRING, lschemaOutA.get(1)); |
| Assert.assertEquals(ValueType.FP64, lschemaOutA.get(2)); |
| Assert.assertEquals(ValueType.BOOLEAN, lschemaOutA.get(3)); |
| |
| Assert.assertEquals(ValueType.STRING, lschemaOutC.get(0)); |
| Assert.assertEquals(ValueType.FP64, lschemaOutC.get(1)); |
| |
| if (outputType == IO_TYPE.JAVA_RDD_STR_CSV) { |
| |
| JavaRDD<String> javaRDDStringCSVA = mlResults.getJavaRDDStringCSV("A"); |
| List<String> linesA = javaRDDStringCSVA.collect(); |
| Assert.assertEquals("1,Str2,3.0,true", linesA.get(0)); |
| Assert.assertEquals("4,Str12,13.0,true", linesA.get(1)); |
| Assert.assertEquals("7,Str25,26.0,false", linesA.get(2)); |
| |
| JavaRDD<String> javaRDDStringCSVC = mlResults.getJavaRDDStringCSV("C"); |
| List<String> linesC = javaRDDStringCSVC.collect(); |
| Assert.assertEquals("Str12,13.0", linesC.get(0)); |
| Assert.assertEquals("Str25,26.0", linesC.get(1)); |
| } else if (outputType == IO_TYPE.JAVA_RDD_STR_IJV) { |
| JavaRDD<String> javaRDDStringIJVA = mlResults.getJavaRDDStringIJV("A"); |
| List<String> linesA = javaRDDStringIJVA.collect(); |
| Assert.assertEquals("1 1 1", linesA.get(0)); |
| Assert.assertEquals("1 2 Str2", linesA.get(1)); |
| Assert.assertEquals("1 3 3.0", linesA.get(2)); |
| Assert.assertEquals("1 4 true", linesA.get(3)); |
| Assert.assertEquals("2 1 4", linesA.get(4)); |
| Assert.assertEquals("2 2 Str12", linesA.get(5)); |
| Assert.assertEquals("2 3 13.0", linesA.get(6)); |
| Assert.assertEquals("2 4 true", linesA.get(7)); |
| |
| JavaRDD<String> javaRDDStringIJVC = mlResults.getJavaRDDStringIJV("C"); |
| List<String> linesC = javaRDDStringIJVC.collect(); |
| Assert.assertEquals("1 1 Str12", linesC.get(0)); |
| Assert.assertEquals("1 2 13.0", linesC.get(1)); |
| Assert.assertEquals("2 1 Str25", linesC.get(2)); |
| Assert.assertEquals("2 2 26.0", linesC.get(3)); |
| } else if (outputType == IO_TYPE.RDD_STR_CSV) { |
| RDD<String> rddStringCSVA = mlResults.getRDDStringCSV("A"); |
| Iterator<String> iteratorA = rddStringCSVA.toLocalIterator(); |
| Assert.assertEquals("1,Str2,3.0,true", iteratorA.next()); |
| Assert.assertEquals("4,Str12,13.0,true", iteratorA.next()); |
| Assert.assertEquals("7,Str25,26.0,false", iteratorA.next()); |
| |
| RDD<String> rddStringCSVC = mlResults.getRDDStringCSV("C"); |
| Iterator<String> iteratorC = rddStringCSVC.toLocalIterator(); |
| Assert.assertEquals("Str12,13.0", iteratorC.next()); |
| Assert.assertEquals("Str25,26.0", iteratorC.next()); |
| } else if (outputType == IO_TYPE.RDD_STR_IJV) { |
| RDD<String> rddStringIJVA = mlResults.getRDDStringIJV("A"); |
| Iterator<String> iteratorA = rddStringIJVA.toLocalIterator(); |
| Assert.assertEquals("1 1 1", iteratorA.next()); |
| Assert.assertEquals("1 2 Str2", iteratorA.next()); |
| Assert.assertEquals("1 3 3.0", iteratorA.next()); |
| Assert.assertEquals("1 4 true", iteratorA.next()); |
| Assert.assertEquals("2 1 4", iteratorA.next()); |
| Assert.assertEquals("2 2 Str12", iteratorA.next()); |
| Assert.assertEquals("2 3 13.0", iteratorA.next()); |
| Assert.assertEquals("2 4 true", iteratorA.next()); |
| Assert.assertEquals("3 1 7", iteratorA.next()); |
| Assert.assertEquals("3 2 Str25", iteratorA.next()); |
| Assert.assertEquals("3 3 26.0", iteratorA.next()); |
| Assert.assertEquals("3 4 false", iteratorA.next()); |
| |
| RDD<String> rddStringIJVC = mlResults.getRDDStringIJV("C"); |
| Iterator<String> iteratorC = rddStringIJVC.toLocalIterator(); |
| Assert.assertEquals("1 1 Str12", iteratorC.next()); |
| Assert.assertEquals("1 2 13.0", iteratorC.next()); |
| Assert.assertEquals("2 1 Str25", iteratorC.next()); |
| Assert.assertEquals("2 2 26.0", iteratorC.next()); |
| |
| } else if (outputType == IO_TYPE.DATAFRAME) { |
| |
| Dataset<Row> dataFrameA = mlResults.getDataFrame("A").drop(RDDConverterUtils.DF_ID_COLUMN); |
| StructType dfschemaA = dataFrameA.schema(); |
| StructField structTypeA = dfschemaA.apply(0); |
| Assert.assertEquals(DataTypes.LongType, structTypeA.dataType()); |
| structTypeA = dfschemaA.apply(1); |
| Assert.assertEquals(DataTypes.StringType, structTypeA.dataType()); |
| structTypeA = dfschemaA.apply(2); |
| Assert.assertEquals(DataTypes.DoubleType, structTypeA.dataType()); |
| structTypeA = dfschemaA.apply(3); |
| Assert.assertEquals(DataTypes.BooleanType, structTypeA.dataType()); |
| |
| List<Row> listAOut = dataFrameA.collectAsList(); |
| |
| Row row1 = listAOut.get(0); |
| Assert.assertEquals("Mismatch with expected value", Long.valueOf(1), row1.get(0)); |
| Assert.assertEquals("Mismatch with expected value", "Str2", row1.get(1)); |
| Assert.assertEquals("Mismatch with expected value", 3.0, row1.get(2)); |
| Assert.assertEquals("Mismatch with expected value", true, row1.get(3)); |
| |
| Row row2 = listAOut.get(1); |
| Assert.assertEquals("Mismatch with expected value", Long.valueOf(4), row2.get(0)); |
| Assert.assertEquals("Mismatch with expected value", "Str12", row2.get(1)); |
| Assert.assertEquals("Mismatch with expected value", 13.0, row2.get(2)); |
| Assert.assertEquals("Mismatch with expected value", true, row2.get(3)); |
| |
| Dataset<Row> dataFrameC = mlResults.getDataFrame("C").drop(RDDConverterUtils.DF_ID_COLUMN); |
| StructType dfschemaC = dataFrameC.schema(); |
| StructField structTypeC = dfschemaC.apply(0); |
| Assert.assertEquals(DataTypes.StringType, structTypeC.dataType()); |
| structTypeC = dfschemaC.apply(1); |
| Assert.assertEquals(DataTypes.DoubleType, structTypeC.dataType()); |
| |
| List<Row> listCOut = dataFrameC.collectAsList(); |
| |
| Row row3 = listCOut.get(0); |
| Assert.assertEquals("Mismatch with expected value", "Str12", row3.get(0)); |
| Assert.assertEquals("Mismatch with expected value", 13.0, row3.get(1)); |
| |
| Row row4 = listCOut.get(1); |
| Assert.assertEquals("Mismatch with expected value", "Str25", row4.get(0)); |
| Assert.assertEquals("Mismatch with expected value", 26.0, row4.get(1)); |
| } else { |
| String[][] frameA = mlResults.getFrameAs2DStringArray("A"); |
| Assert.assertEquals("Str2", frameA[0][1]); |
| Assert.assertEquals("3.0", frameA[0][2]); |
| Assert.assertEquals("13.0", frameA[1][2]); |
| Assert.assertEquals("true", frameA[1][3]); |
| Assert.assertEquals("Str25", frameA[2][1]); |
| |
| String[][] frameC = mlResults.getFrameAs2DStringArray("C"); |
| Assert.assertEquals("Str12", frameC[0][0]); |
| Assert.assertEquals("Str25", frameC[1][0]); |
| Assert.assertEquals("13.0", frameC[0][1]); |
| Assert.assertEquals("26.0", frameC[1][1]); |
| } |
| } |
| |
| @Test |
| public void testOutputFrameDML() { |
| System.out.println("MLContextFrameTest - output frame DML"); |
| |
| String s = "M = read($Min, data_type='frame', format='csv');"; |
| String csvFile = baseDirectory + File.separator + "one-two-three-four.csv"; |
| Script script = dml(s).in("$Min", csvFile).out("M"); |
| String[][] frame = ml.execute(script).getFrameAs2DStringArray("M"); |
| Assert.assertEquals("one", frame[0][0]); |
| Assert.assertEquals("two", frame[0][1]); |
| Assert.assertEquals("three", frame[1][0]); |
| Assert.assertEquals("four", frame[1][1]); |
| } |
| |
| @Test |
| public void testInputFrameAndMatrixOutputMatrix() { |
| System.out.println("MLContextFrameTest - input frame and matrix, output matrix"); |
| |
| List<String> dataA = new ArrayList<>(); |
| dataA.add("Test1,4.0"); |
| dataA.add("Test2,5.0"); |
| dataA.add("Test3,6.0"); |
| JavaRDD<String> javaRddStringA = sc.parallelize(dataA); |
| ValueType[] schema = { ValueType.STRING, ValueType.FP64 }; |
| |
| List<String> dataB = new ArrayList<>(); |
| dataB.add("1.0"); |
| dataB.add("2.0"); |
| JavaRDD<String> javaRddStringB = sc.parallelize(dataB); |
| |
| JavaRDD<Row> javaRddRowA = FrameRDDConverterUtils.csvToRowRDD(sc, javaRddStringA, CSV_DELIM, schema); |
| JavaRDD<Row> javaRddRowB = javaRddStringB.map(new CommaSeparatedValueStringToDoubleArrayRow()); |
| |
| List<StructField> fieldsA = new ArrayList<>(); |
| fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, true)); |
| fieldsA.add(DataTypes.createStructField("2", DataTypes.DoubleType, true)); |
| StructType schemaA = DataTypes.createStructType(fieldsA); |
| Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); |
| |
| List<StructField> fieldsB = new ArrayList<>(); |
| fieldsB.add(DataTypes.createStructField("1", DataTypes.DoubleType, true)); |
| StructType schemaB = DataTypes.createStructType(fieldsB); |
| Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, schemaB); |
| |
| String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: true ,recode: [ 1, 2 ]}\");\n" |
| + "C = tA %*% B;\n" + "M = s * C;"; |
| |
| Script script = dml(dmlString) |
| .in("A", dataFrameA, |
| new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) dataFrameA.columns().length)) |
| .in("B", dataFrameB, |
| new MatrixMetadata(MatrixFormat.CSV, dataFrameB.count(), (long) dataFrameB.columns().length)) |
| .in("s", 2).out("M"); |
| MLResults results = ml.execute(script); |
| double[][] matrix = results.getMatrixAs2DDoubleArray("M"); |
| Assert.assertEquals(6.0, matrix[0][0], 0.0); |
| Assert.assertEquals(12.0, matrix[1][0], 0.0); |
| Assert.assertEquals(18.0, matrix[2][0], 0.0); |
| } |
| |
| @Test |
| public void testInputFrameAndMatrixOutputMatrixAndFrame() { |
| System.out.println("MLContextFrameTest - input frame and matrix, output matrix and frame"); |
| |
| Row[] rowsA = {RowFactory.create("Doc1", "Feat1", 10), RowFactory.create("Doc1", "Feat2", 20), RowFactory.create("Doc2", "Feat1", 31)}; |
| |
| JavaRDD<Row> javaRddRowA = sc. parallelize( Arrays.asList(rowsA)); |
| |
| List<StructField> fieldsA = new ArrayList<>(); |
| fieldsA.add(DataTypes.createStructField("myID", DataTypes.StringType, true)); |
| fieldsA.add(DataTypes.createStructField("FeatureName", DataTypes.StringType, true)); |
| fieldsA.add(DataTypes.createStructField("FeatureValue", DataTypes.IntegerType, true)); |
| StructType schemaA = DataTypes.createStructType(fieldsA); |
| Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); |
| |
| String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: false ,recode: [ myID, FeatureName ]}\");"; |
| |
| Script script = dml(dmlString) |
| .in("A", dataFrameA, |
| new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) dataFrameA.columns().length)) |
| .out("tA").out("tAM"); |
| MLResults results = ml.execute(script); |
| |
| double[][] matrixtA = results.getMatrixAs2DDoubleArray("tA"); |
| Assert.assertEquals(10.0, matrixtA[0][2], 0.0); |
| Assert.assertEquals(20.0, matrixtA[1][2], 0.0); |
| Assert.assertEquals(31.0, matrixtA[2][2], 0.0); |
| |
| Dataset<Row> dataFrame_tA = results.getMatrix("tA").toDF(); |
| System.out.println("Number of matrix tA rows = " + dataFrame_tA.count()); |
| dataFrame_tA.printSchema(); |
| dataFrame_tA.show(); |
| |
| Dataset<Row> dataFrame_tAM = results.getFrame("tAM").toDF(); |
| System.out.println("Number of frame tAM rows = " + dataFrame_tAM.count()); |
| dataFrame_tAM.printSchema(); |
| dataFrame_tAM.show(); |
| } |
| |
| @Test |
| public void testTransform() { |
| System.out.println("MLContextFrameTest - transform"); |
| |
| Row[] rowsA = {RowFactory.create("\"`@(\"(!&",2,"20news-bydate-train/comp.os.ms-windows.misc/9979"), |
| RowFactory.create("\"`@(\"\"(!&\"",3,"20news-bydate-train/comp.os.ms-windows.misc/9979")}; |
| |
| JavaRDD<Row> javaRddRowA = sc. parallelize( Arrays.asList(rowsA)); |
| |
| List<StructField> fieldsA = new ArrayList<>(); |
| fieldsA.add(DataTypes.createStructField("featureName", DataTypes.StringType, true)); |
| fieldsA.add(DataTypes.createStructField("featureValue", DataTypes.IntegerType, true)); |
| fieldsA.add(DataTypes.createStructField("id", DataTypes.StringType, true)); |
| StructType schemaA = DataTypes.createStructType(fieldsA); |
| Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); |
| |
| String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: false ,recode: [ featureName, id ]}\");"; |
| |
| Script script = dml(dmlString) |
| .in("A", dataFrameA, |
| new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) dataFrameA.columns().length)) |
| .out("tA").out("tAM"); |
| ml.setExplain(true); |
| ml.setExplainLevel(ExplainLevel.RECOMPILE_HOPS); |
| MLResults results = ml.execute(script); |
| |
| double[][] matrixtA = results.getMatrixAs2DDoubleArray("tA"); |
| Assert.assertEquals(1.0, matrixtA[0][2], 0.0); |
| |
| Dataset<Row> dataFrame_tA = results.getMatrix("tA").toDF(); |
| System.out.println("Number of matrix tA rows = " + dataFrame_tA.count()); |
| dataFrame_tA.printSchema(); |
| dataFrame_tA.show(); |
| |
| Dataset<Row> dataFrame_tAM = results.getFrame("tAM").toDF(); |
| System.out.println("Number of frame tAM rows = " + dataFrame_tAM.count()); |
| dataFrame_tAM.printSchema(); |
| dataFrame_tAM.show(); |
| } |
| |
// NOTE: the ordering of the frame values seems to come out differently here
// than in the scala shell,
// so this should be investigated or explained.
| // @Test |
| // public void testInputFrameOutputMatrixAndFrame() { |
| // System.out.println("MLContextFrameTest - input frame, output matrix and |
| // frame"); |
| // |
| // List<String> dataA = new ArrayList<String>(); |
| // dataA.add("Test1,Test4"); |
| // dataA.add("Test2,Test5"); |
| // dataA.add("Test3,Test6"); |
| // JavaRDD<String> javaRddStringA = sc.parallelize(dataA); |
| // |
| // JavaRDD<Row> javaRddRowA = javaRddStringA.map(new |
| // CommaSeparatedValueStringToRow()); |
| // |
| // List<StructField> fieldsA = new ArrayList<StructField>(); |
| // fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, |
| // true)); |
| // fieldsA.add(DataTypes.createStructField("2", DataTypes.StringType, |
| // true)); |
| // StructType schemaA = DataTypes.createStructType(fieldsA); |
| // DataFrame dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); |
| // |
| // String dmlString = "[tA, tAM] = transformencode (target = A, spec = |
| // \"{ids: true ,recode: [ 1, 2 ]}\");\n"; |
| // |
| // Script script = dml(dmlString) |
| // .in("A", dataFrameA, |
| // new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) |
| // dataFrameA.columns().length)) |
| // .out("tA", "tAM"); |
| // MLResults results = ml.execute(script); |
| // double[][] matrix = results.getMatrixAs2DDoubleArray("tA"); |
| // Assert.assertEquals(1.0, matrix[0][0], 0.0); |
| // Assert.assertEquals(1.0, matrix[0][1], 0.0); |
| // Assert.assertEquals(2.0, matrix[1][0], 0.0); |
| // Assert.assertEquals(2.0, matrix[1][1], 0.0); |
| // Assert.assertEquals(3.0, matrix[2][0], 0.0); |
| // Assert.assertEquals(3.0, matrix[2][1], 0.0); |
| // |
| // TODO: Add asserts for frame if ordering is as expected |
| // String[][] frame = results.getFrameAs2DStringArray("tAM"); |
| // for (int i = 0; i < frame.length; i++) { |
| // for (int j = 0; j < frame[i].length; j++) { |
| // System.out.println("[" + i + "][" + j + "]:" + frame[i][j]); |
| // } |
| // } |
| // } |
| |
| } |