blob: 67d5626642314cdf2fd9d21bb4c9903995dc48a3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.functions.frame;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.io.FrameWriter;
import org.apache.sysds.runtime.io.FrameWriterFactory;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import java.security.SecureRandom;
public class DetectSchemaTest extends AutomatedTestBase {
private final static String TEST_NAME = "DetectSchema";
private final static String TEST_DIR = "functions/frame/";
private static final String TEST_CLASS_DIR = TEST_DIR + DetectSchemaTest.class.getSimpleName() + "/";
private final static int rows = 10000;
private final static Types.ValueType[] schemaStrings = {Types.ValueType.INT32, Types.ValueType.BOOLEAN, Types.ValueType.FP32, Types.ValueType.STRING, Types.ValueType.STRING, Types.ValueType.FP32};
private final static Types.ValueType[] schemaDoubles = new Types.ValueType[]{Types.ValueType.FP64, Types.ValueType.FP64};
private final static Types.ValueType[] schemaMixed = new Types.ValueType[]{Types.ValueType.INT64, Types.ValueType.FP64, Types.ValueType.INT64, Types.ValueType.BOOLEAN};
@BeforeClass
public static void init() {
TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
}
@AfterClass
public static void cleanUp() {
if (TEST_CACHE_ENABLED) {
TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
}
}
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"}));
if (TEST_CACHE_ENABLED) {
setOutAndExpectedDeletionDisabled(true);
}
}
@Test
public void testDetectSchemaDoubleCP() {
runDetectSchemaTest(schemaDoubles, rows, schemaDoubles.length, false, ExecType.CP);
}
@Test
public void testDetectSchemaDoubleSpark() {
runDetectSchemaTest(schemaDoubles, rows, schemaDoubles.length, false, ExecType.SPARK);
}
@Test
public void testDetectSchemaStringCP() {
runDetectSchemaTest(schemaStrings, rows, schemaStrings.length, true, ExecType.CP);
}
@Test
public void testDetectSchemaStringSpark() {
runDetectSchemaTest(schemaStrings, rows, schemaStrings.length, true, ExecType.SPARK);
}
@Test
public void testDetectSchemaMixCP() {
runDetectSchemaTest(schemaMixed, rows, schemaMixed.length, false, ExecType.CP);
}
@Test
public void testDetectSchemaMixSpark() {
runDetectSchemaTest(schemaMixed, rows, schemaMixed.length, false, ExecType.SPARK);
}
private void runDetectSchemaTest(Types.ValueType[] schema, int rows, int cols, boolean isStringTest, ExecType et) {
Types.ExecMode platformOld = setExecMode(et);
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
try {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[]{"-explain", "recompile_runtime", "-args", input("A"), String.valueOf(rows), Integer.toString(cols), output("B")};
FrameBlock frame1 = new FrameBlock(schema);
FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV);
if (!isStringTest) {
double[][] A = getRandomMatrix(rows, schema.length, -Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 2373);
initFrameDataDouble(frame1, A, schema);
writer.writeFrameToHDFS(frame1, input("A"), rows, schema.length);
}
else {
double[][] A = getRandomMatrix(rows, 3, -Float.MAX_VALUE, Float.MAX_VALUE, 0.7, 2373);
initFrameDataString(frame1, A, schema);
writer.writeFrameToHDFS(frame1.slice(0, rows-1, 0, schema.length-1, new FrameBlock()), input("A"), rows, schema.length);
schema[schema.length-2] = Types.ValueType.FP64;
}
runTest(true, false, null, -1);
FrameBlock frame2 = readDMLFrameFromHDFS("B", FileFormat.BINARY);
//verify output schema
for (int i = 0; i < schema.length; i++) {
Assert.assertEquals("Wrong result: " + frame2.getSchema()[i] + ".",
schema[i].toString(), frame2.get(0, i).toString());
}
}
catch (Exception ex) {
throw new RuntimeException(ex);
}
finally {
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
}
}
private static void initFrameDataString(FrameBlock frame1, double[][] data, Types.ValueType[] lschema) {
for (int j = 0; j < 3; j++) {
Types.ValueType vt = lschema[j];
switch (vt) {
case STRING:
String[] tmp1 = new String[rows];
for (int i = 0; i < rows; i++)
tmp1[i] = (String) UtilFunctions.doubleToObject(vt, data[i][j]);
frame1.appendColumn(tmp1);
break;
case BOOLEAN:
boolean[] tmp2 = new boolean[rows];
for (int i = 0; i < rows; i++)
data[i][j] = (tmp2[i] = (Boolean) UtilFunctions.doubleToObject(vt, data[i][j], false)) ? 1 : 0;
frame1.appendColumn(tmp2);
break;
case INT32:
int[] tmp3 = new int[rows];
for (int i = 0; i < rows; i++)
data[i][j] = tmp3[i] = (Integer) UtilFunctions.doubleToObject(Types.ValueType.INT32, data[i][j], false);
frame1.appendColumn(tmp3);
break;
case INT64:
long[] tmp4 = new long[rows];
for (int i = 0; i < rows; i++)
data[i][j] = tmp4[i] = (Long) UtilFunctions.doubleToObject(Types.ValueType.INT64, data[i][j], false);
frame1.appendColumn(tmp4);
break;
case FP32:
double[] tmp5 = new double[rows];
for (int i = 0; i < rows; i++)
tmp5[i] = (Float) UtilFunctions.doubleToObject(vt, data[i][j], false);
frame1.appendColumn(tmp5);
break;
case FP64:
double[] tmp6 = new double[rows];
for (int i = 0; i < rows; i++)
tmp6[i] = (Double) UtilFunctions.doubleToObject(vt, data[i][j], false);
frame1.appendColumn(tmp6);
break;
default:
throw new RuntimeException("Unsupported value type: " + vt);
}
}
String[] randomData = generateRandomString(8, rows);
frame1.appendColumn(randomData);
frame1.appendColumn(doubleSpecialData(rows));
frame1.appendColumn(floatLimitData(rows));
}
private static void initFrameDataDouble(FrameBlock frame, double[][] data, Types.ValueType[] lschema) {
Object[] row1 = new Object[lschema.length];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < lschema.length; j++) {
data[i][j] = UtilFunctions.objectToDouble(lschema[j],
row1[j] = UtilFunctions.doubleToObject(lschema[j], data[i][j]));
}
frame.appendRow(row1);
}
}
private static String[] generateRandomString(int stringLength, int rows) {
String CHAR_LOWER = "abcdefghijklmnopqrstuvwxyz";
String CHAR_UPPER = CHAR_LOWER.toUpperCase();
String NUMBER = "0123456789";
String DATA_FOR_RANDOM_STRING = CHAR_LOWER + CHAR_UPPER + NUMBER;
String[] A = new String[rows];
SecureRandom random = new SecureRandom();
if (stringLength < 1) throw new IllegalArgumentException();
for (int j = 0; j < rows; j++) {
StringBuilder sb = new StringBuilder(stringLength);
for (int i = 0; i < stringLength; i++) {
int rndCharAt = random.nextInt(DATA_FOR_RANDOM_STRING.length());
char rndChar = DATA_FOR_RANDOM_STRING.charAt(rndCharAt);
sb.append(rndChar);
}
A[j] = sb.toString();
}
return A;
}
private static String[] doubleSpecialData(int rows) {
String[] dataArray = new String[]{"Infinity", "3.4028234e+38", "Nan" , "-3.4028236e+38" };
String[] A = new String[rows];
SecureRandom random = new SecureRandom();
for (int j = 0; j < rows; j++)
A[j] = dataArray[random.nextInt(4)];
return A;
}
private static double[] floatLimitData(int rows) {
double[] dataArray = new double[]{Float.MAX_VALUE, 3.4028233E38, 3.4028234e38 , 3.4028228e38, 2.4028228e38, -3.4028234e38, -3.40282310e38};
double[] A = new double[rows];
SecureRandom random = new SecureRandom();
for (int j = 0; j < rows; j++)
A[j] = dataArray[random.nextInt(7)];
return A;
}
}