blob: ce3e56c76efcef18c7455e98116ffc2a0a25f283 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.component.tensor;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.util.UtilFunctions;
import java.util.Arrays;
/**
 * Component tests for single-cell {@code set}/{@code get} indexing on {@link TensorBlock}.
 * <p>
 * Each test allocates a 2D ({@code DIM0 x DIM1}) or 3D ({@code DIM0 x DIM1 x DIM2}) tensor of a
 * given value type, writes a row-major linear sequence into every cell via {@code set}, and then
 * reads every cell back via {@code get}, comparing it against the expected sequence value.
 * Both basic tensors (one value type) and data tensors (a schema with one value type per entry
 * of the second dimension) are covered.
 */
public class TensorGetSetIndexingTest
{
	/** Size of the first dimension (rows). */
	private static final int DIM0 = 3;
	/** Size of the second dimension; also the schema length of data tensors. */
	private static final int DIM1 = 5;
	/** Size of the third dimension; only used by 3D tensors. */
	private static final int DIM2 = 7;

	// TODO large tensor tests

	@Test
	public void testIndexBasicTensor2FP32SetGetCell() {
		TensorBlock tb = getBasicTensor2(ValueType.FP32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor2FP64SetGetCell() {
		TensorBlock tb = getBasicTensor2(ValueType.FP64);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor2BoolSetGetCell() {
		TensorBlock tb = getBasicTensor2(ValueType.BOOLEAN);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor2Int32SetGetCell() {
		TensorBlock tb = getBasicTensor2(ValueType.INT32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor2Int64SetGetCell() {
		TensorBlock tb = getBasicTensor2(ValueType.INT64);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor3FP32SetGetCell() {
		TensorBlock tb = getBasicTensor3(ValueType.FP32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor3FP64SetGetCell() {
		TensorBlock tb = getBasicTensor3(ValueType.FP64);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor3BoolSetGetCell() {
		TensorBlock tb = getBasicTensor3(ValueType.BOOLEAN);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor3Int32SetGetCell() {
		TensorBlock tb = getBasicTensor3(ValueType.INT32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexBasicTensor3Int64SetGetCell() {
		TensorBlock tb = getBasicTensor3(ValueType.INT64);
		checkSequence(setSequence(tb));
	}

	/** Creates a basic {@code DIM0 x DIM1} tensor of the given value type. */
	private static TensorBlock getBasicTensor2(ValueType vt) {
		// TODO: implement sparse for Tensor, then add sparse variants of these tests
		return new TensorBlock(vt, new int[] {DIM0, DIM1});
	}

	/** Creates a basic {@code DIM0 x DIM1 x DIM2} tensor of the given value type. */
	private static TensorBlock getBasicTensor3(ValueType vt) {
		// TODO: implement sparse for Tensor, then add sparse variants of these tests
		return new TensorBlock(vt, new int[] {DIM0, DIM1, DIM2});
	}

	/**
	 * Writes a row-major linear sequence into every cell of the given tensor: cell {@code (i,j)}
	 * receives {@code i*DIM1+j} and cell {@code (i,j,k)} receives {@code i*DIM1*DIM2+j*DIM2+k}.
	 *
	 * @param tb a 2D or 3D tensor shaped {@code DIM0 x DIM1 [x DIM2]}
	 * @return the same tensor, to allow chaining into {@link #checkSequence(TensorBlock)}
	 * @throws IllegalArgumentException if the tensor is neither 2D nor 3D
	 */
	private static TensorBlock setSequence(TensorBlock tb) {
		// Dispatch on the rank (number of dimensions) via literal ranks, not via a dimension
		// size such as DIM0, so that resizing the test tensors cannot change the branch taken.
		int numDims = tb.getNumDims();
		if( numDims == 3 ) {
			int dim12 = DIM1 * DIM2;
			for(int i=0; i<tb.getNumRows(); i++)
				for(int j=0; j<DIM1; j++)
					for(int k=0; k<DIM2; k++)
						tb.set(new int[] {i,j,k}, (double)i*dim12+j*DIM2+k);
		}
		else if( numDims == 2 ) {
			for(int i=0; i<tb.getNumRows(); i++)
				for(int j=0; j<DIM1; j++)
					tb.set(new int[] {i,j}, i*DIM1+j);
		}
		else {
			throw new IllegalArgumentException("Unsupported number of dimensions: " + numDims);
		}
		return tb;
	}

	/**
	 * Reads back every cell of a tensor filled by {@link #setSequence(TensorBlock)} and asserts
	 * that it holds the expected sequence value (for boolean tensors: 0 for the first cell, 1 for
	 * every other cell).
	 *
	 * @param tb a 2D or 3D tensor previously passed through {@code setSequence}
	 * @throws IllegalArgumentException if the tensor is neither 2D nor 3D
	 */
	private static void checkSequence(TensorBlock tb) {
		// All test schemas are homogeneous, so the first schema entry decides boolean-ness.
		boolean isBool = (tb.isBasic() ? tb.getValueType() : tb.getSchema()[0]) == ValueType.BOOLEAN;
		int numDims = tb.getNumDims();
		if( numDims == 3 ) {
			int dim12 = DIM1 * DIM2;
			for(int i=0; i<tb.getNumRows(); i++)
				for(int j=0; j<DIM1; j++)
					for(int k=0; k<DIM2; k++) {
						double expected = expectedValue(isBool, i*dim12+j*DIM2+k);
						double actual = UtilFunctions.objectToDouble(
							cellValueType(tb, j), tb.get(new int[] {i,j,k}));
						Assert.assertEquals(expected, actual, 0);
					}
		}
		else if( numDims == 2 ) {
			for(int i=0; i<tb.getNumRows(); i++)
				for(int j=0; j<DIM1; j++) {
					double expected = expectedValue(isBool, i*DIM1+j);
					double actual = UtilFunctions.objectToDouble(
						cellValueType(tb, j), tb.get(new int[] {i,j}));
					Assert.assertEquals(expected, actual, 0);
				}
		}
		else {
			throw new IllegalArgumentException("Unsupported number of dimensions: " + numDims);
		}
	}

	/** Maps a written sequence value to the expected read-back value (non-zero -> 1 for booleans). */
	private static double expectedValue(boolean isBool, int val) {
		return isBool && val != 0 ? 1 : val;
	}

	/** Returns the value type of cells at second-dimension index {@code j}. */
	private static ValueType cellValueType(TensorBlock tb, int j) {
		return tb.isBasic() ? tb.getValueType() : tb.getSchema()[j];
	}

	@Test
	public void testIndexDataTensor2FP32SetGetCell() {
		TensorBlock tb = getDataTensor2(ValueType.FP32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor2FP64SetGetCell() {
		TensorBlock tb = getDataTensor2(ValueType.FP64);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor2BoolSetGetCell() {
		TensorBlock tb = getDataTensor2(ValueType.BOOLEAN);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor2Int32SetGetCell() {
		TensorBlock tb = getDataTensor2(ValueType.INT32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor2Int64SetGetCell() {
		TensorBlock tb = getDataTensor2(ValueType.INT64);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor3FP32SetGetCell() {
		TensorBlock tb = getDataTensor3(ValueType.FP32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor3FP64SetGetCell() {
		TensorBlock tb = getDataTensor3(ValueType.FP64);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor3BoolSetGetCell() {
		TensorBlock tb = getDataTensor3(ValueType.BOOLEAN);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor3Int32SetGetCell() {
		TensorBlock tb = getDataTensor3(ValueType.INT32);
		checkSequence(setSequence(tb));
	}

	@Test
	public void testIndexDataTensor3Int64SetGetCell() {
		TensorBlock tb = getDataTensor3(ValueType.INT64);
		checkSequence(setSequence(tb));
	}

	/** Creates a {@code DIM0 x DIM1} data tensor whose schema assigns {@code vt} to every column. */
	private static TensorBlock getDataTensor2(ValueType vt) {
		ValueType[] schema = new ValueType[DIM1];
		Arrays.fill(schema, vt);
		return new TensorBlock(schema, new int[] {DIM0, DIM1});
	}

	/** Creates a {@code DIM0 x DIM1 x DIM2} data tensor whose schema assigns {@code vt} to every column. */
	private static TensorBlock getDataTensor3(ValueType vt) {
		ValueType[] schema = new ValueType[DIM1];
		Arrays.fill(schema, vt);
		return new TensorBlock(schema, new int[] {DIM0, DIM1, DIM2});
	}
}