blob: 4312c4a040204b977f6f2fa05ddf28a10e6aee81 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.component.frame;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
/**
 * Component test comparing frame indexing against matrix indexing.
 * <p>
 * Each test builds a {@link FrameBlock} and a {@link MatrixBlock} from the same
 * data, applies the same indexing operation to both (either {@code slice} or
 * {@code leftIndexingOperations}) and checks that the frame result matches the
 * matrix result in its dimensions and in every cell.
 */
public class FrameIndexingTest extends AutomatedTestBase
{
	private final static int rows = 3345;
	//inclusive row/column bounds of the indexed region
	private final static int rl = 234;
	private final static int ru = 1432;
	private final static int cl = 0;
	private final static int cu = 2;

	private final static ValueType[] schemaStrings = new ValueType[]{ValueType.STRING, ValueType.STRING, ValueType.STRING};
	private final static ValueType[] schemaMixed = new ValueType[]{ValueType.STRING, ValueType.FP64, ValueType.INT64, ValueType.BOOLEAN};

	/** Indexing operation under test. */
	private enum IXType {
		RIX, //exercised via slice
		LIX, //exercised via leftIndexingOperations
	}

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
	}

	@Test
	public void testFrameStringsRIX() {
		runFrameIndexingTest(schemaStrings, IXType.RIX);
	}

	@Test
	public void testFrameMixedRIX() {
		runFrameIndexingTest(schemaMixed, IXType.RIX);
	}

	@Test
	public void testFrameStringsLIX() {
		runFrameIndexingTest(schemaStrings, IXType.LIX);
	}

	@Test
	public void testFrameMixedLIX() {
		runFrameIndexingTest(schemaMixed, IXType.LIX);
	}

	/**
	 * Runs one indexing operation on a frame and on the equivalent matrix and
	 * compares the two results.
	 *
	 * @param schema column value types of the input frame
	 * @param itype indexing operation to apply
	 */
	private void runFrameIndexingTest( ValueType[] schema, IXType itype)
	{
		try
		{
			//data generation
			double[][] A = getRandomMatrix(rows, schema.length, -10, 10, 0.9, 2412);

			//init data frame 1 (also aligns A with the schema value types)
			FrameBlock frame1 = createFrame(schema, A);

			//core indexing operation
			MatrixBlock mbC = null;
			FrameBlock frame3 = null;
			if( itype == IXType.RIX )
			{
				//matrix indexing
				MatrixBlock mbA = DataConverter.convertToMatrixBlock(A);
				mbC = mbA.slice(rl, ru, cl, cu, new MatrixBlock());

				//frame indexing
				frame3 = frame1.slice(rl, ru, cl, cu, new FrameBlock());
			}
			else if( itype == IXType.LIX )
			{
				//data generation
				double[][] B = getRandomMatrix(ru-rl+1, cu-cl+1, -10, 10, 0.9, 7);

				//init data frame 2 with the schema of the indexed column range
				ValueType[] lschema2 = new ValueType[cu-cl+1];
				for( int j=cl; j<=cu; j++ )
					lschema2[j-cl] = schema[j];
				FrameBlock frame2 = createFrame(lschema2, B);

				//matrix indexing
				MatrixBlock mbA = DataConverter.convertToMatrixBlock(A);
				MatrixBlock mbB = DataConverter.convertToMatrixBlock(B);
				mbC = mbA.leftIndexingOperations(mbB, rl, ru, cl, cu, new MatrixBlock(), UpdateType.COPY);

				//frame indexing
				frame3 = frame1.leftIndexingOperations(frame2, rl, ru, cl, cu, new FrameBlock());
			}
			else {
				//fail explicitly instead of running into a NullPointerException below
				Assert.fail("Unsupported indexing type: "+itype);
			}

			//check basic meta data
			if( frame3.getNumRows() != mbC.getNumRows() )
				Assert.fail("Wrong number of rows: "+frame3.getNumRows()+", expected: "+mbC.getNumRows());
			if( frame3.getNumColumns() != mbC.getNumColumns() )
				Assert.fail("Wrong number of columns: "+frame3.getNumColumns()+", expected: "+mbC.getNumColumns());

			//check correct values over the entire result; the result row count
			//differs by operation, so it is taken from the result itself rather
			//than from the size of the indexed region
			ValueType[] lschema = frame3.getSchema();
			int numRows = mbC.getNumRows();
			for( int i=0; i<numRows; i++ )
				for( int j=0; j<lschema.length; j++ ) {
					double expected = mbC.quickGetValue(i, j);
					double actual = UtilFunctions.objectToDouble(lschema[j], frame3.get(i, j));
					if( actual != expected )
						Assert.fail("Wrong get value for cell ("+i+","+j+"): "+actual+", expected: "+expected);
				}
		}
		catch(Exception ex) {
			//the wrapped cause carries the full stack trace
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Builds a frame from the given data, one frame row per data row.
	 * <p>
	 * Each cell is converted into an object of its column's value type and then
	 * converted back to a double that is written into {@code data}. After this
	 * call {@code data} therefore holds the values as the frame represents them,
	 * so a matrix built from {@code data} can be compared cell by cell against
	 * the frame.
	 *
	 * @param schema column value types of the frame
	 * @param data input values, modified in place
	 * @return the populated frame
	 */
	private static FrameBlock createFrame( ValueType[] schema, double[][] data )
	{
		FrameBlock frame = new FrameBlock(schema);
		Object[] row = new Object[schema.length];
		for( int i=0; i<data.length; i++ ) {
			for( int j=0; j<schema.length; j++ ) {
				row[j] = UtilFunctions.doubleToObject(schema[j], data[i][j]);
				data[i][j] = UtilFunctions.objectToDouble(schema[j], row[j]);
			}
			frame.appendRow(row);
		}
		return frame;
	}
}