blob: 8ab52d9f64000dcb8c148251e17dfe602c1d8c4a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.functions.transform;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Ignore;
import org.junit.Test;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
{
private final static String TEST_NAME1 = "TransformFrameEncodeWordEmbeddings2";
private final static String TEST_NAME2 = "TransformFrameEncodeWordEmbeddings2MultiCols1";
private final static String TEST_NAME3 = "TransformFrameEncodeWordEmbeddings2MultiCols2";
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbeddingTest.class.getSimpleName() + "/";
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, TEST_NAME2));
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, TEST_NAME3));
}
@Test
public void testTransformToWordEmbeddings() {
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE);
}
@Test
@Ignore
public void testNonRandomTransformToWordEmbeddings2Cols() {
runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE);
}
@Test
@Ignore
public void testRandomTransformToWordEmbeddings4Cols() {
runTransformTestMultiCols(TEST_NAME3, ExecMode.SINGLE_NODE);
}
private void runTransformTest(String testname, ExecMode rt)
{
//set runtime platform
ExecMode rtold = setExecMode(rt);
try
{
int rows = 100;
int cols = 100;
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
// Generate random embeddings for the distinct tokens
double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime());
// Generate random distinct tokens
List<String> strings = generateRandomStrings(rows, 10);
// Generate the dictionary by assigning unique ID to each distinct token
Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
// Create the dataset by repeating and shuffling the distinct tokens
List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 320);
writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
//run script
programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result")};
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
// Manually derive the expected result
double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
// Compare results
HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual);
//System.out.println("Actual Result [" + resultActualDouble.length + "x" + resultActualDouble[0].length + "]:");
//print2DimDoubleArray(resultActualDouble);
//System.out.println("\nExpected Result [" + res_expected.length + "x" + res_expected[0].length + "]:");
//print2DimDoubleArray(res_expected);
TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
}
catch(Exception ex) {
throw new RuntimeException(ex);
}
finally {
resetExecMode(rtold);
}
}
private void print2DimDoubleArray(double[][] resultActualDouble) {
Arrays.stream(resultActualDouble).forEach(
e -> System.out.println(Arrays.stream(e).mapToObj(d -> String.format("%06.1f", d))
.reduce("", (sub, elem) -> sub + " " + elem)));
}
private void runTransformTestMultiCols(String testname, ExecMode rt)
{
//set runtime platform
ExecMode rtold = setExecMode(rt);
try
{
int rows = 100;
int cols = 100;
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
// Generate random embeddings for the distinct tokens
double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime());
// Generate random distinct tokens
List<String> strings = generateRandomStrings(rows, 10);
// Generate the dictionary by assigning unique ID to each distinct token
Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
// Create the dataset by repeating and shuffling the distinct tokens
List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 10);
writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
//run script
programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result"), output("result2")};
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
// Manually derive the expected result
double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
// Compare results
HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
HashMap<MatrixValue.CellIndex, Double> res_actual2 = readDMLMatrixFromOutputDir("result2");
double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual);
double[][] resultActualDouble2 = TestUtils.convertHashMapToDoubleArray(res_actual2);
//System.out.println("Actual Result1 [" + resultActualDouble.length + "x" + resultActualDouble[0].length + "]:");
///print2DimDoubleArray(resultActualDouble);
//System.out.println("\nActual Result2 [" + resultActualDouble.length + "x" + resultActualDouble[0].length + "]:");
//print2DimDoubleArray(resultActualDouble2);
//System.out.println("\nExpected Result [" + res_expected.length + "x" + res_expected[0].length + "]:");
//print2DimDoubleArray(res_expected);
TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
TestUtils.compareMatrices(resultActualDouble, resultActualDouble2, 1e-6);
}
catch(Exception ex) {
throw new RuntimeException(ex);
}
finally {
resetExecMode(rtold);
}
}
private double[][] manuallyDeriveWordEmbeddings(int cols, double[][] a, Map<String, Integer> map, List<String> stringsColumn) {
// Manually derive the expected result
double[][] res_expected = new double[stringsColumn.size()][cols];
for (int i = 0; i < stringsColumn.size(); i++) {
int rowMapped = map.get(stringsColumn.get(i));
System.arraycopy(a[rowMapped], 0, res_expected[i], 0, cols);
}
return res_expected;
}
private double[][] generateWordEmbeddings(int rows, int cols) {
double[][] a = new double[rows][cols];
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < a[i].length; j++) {
a[i][j] = cols *i + j;
}
}
return a;
}
public static List<String> shuffleAndMultiplyStrings(List<String> strings, int multiply){
List<String> out = new ArrayList<>();
Random random = new Random();
for (int i = 0; i < strings.size()*multiply; i++) {
out.add(strings.get(random.nextInt(strings.size())));
}
return out;
}
public static List<String> generateRandomStrings(int numStrings, int stringLength) {
List<String> randomStrings = new ArrayList<>();
Random random = new Random();
String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
for (int i = 0; i < numStrings; i++) {
randomStrings.add(generateRandomString(random, stringLength, characters));
}
return randomStrings;
}
public static String generateRandomString(Random random, int stringLength, String characters){
StringBuilder randomString = new StringBuilder();
for (int j = 0; j < stringLength; j++) {
int randomIndex = random.nextInt(characters.length());
randomString.append(characters.charAt(randomIndex));
}
return randomString.toString();
}
public static void writeStringsToCsvFile(List<String> strings, String fileName) {
try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) {
for (String line : strings) {
bw.write(line);
bw.newLine();
}
} catch (IOException e) {
e.printStackTrace();
}
}
public static Map<String,Integer> writeDictToCsvFile(List<String> strings, String fileName) {
try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) {
Map<String,Integer> map = new HashMap<>();
for (int i = 0; i < strings.size(); i++) {
map.put(strings.get(i), i);
bw.write(strings.get(i) + Lop.DATATYPE_PREFIX + (i+1) + "\n");
}
return map;
} catch (IOException e) {
e.printStackTrace();
return null;
}
}
}