blob: 09e0c59e2f055a509dc59c93b5c4496596fbe1ce [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysds.test.component.sparse;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
/**
 * This is a sparse block component test for the get-first-index
 * functionality (posFIndexGT, posFIndexGTE, posFIndexLTE). To achieve broad
 * coverage, we test GT, GTE, and LTE probes against the MCSR, CSR, and COO
 * sparse block formats at three different sparsity values.
 */
public class SparseBlockGetFirstIndex extends AutomatedTestBase
// NOTE(review): the class's opening and closing braces are not present in this
// copy of the file.

// Dimensions of the random test matrix. rows < cols, so the probe column used
// by the index checks (equal to the row index) always lies inside the matrix.
private final static int rows = 571;
private final static int cols = 595;
// Sparsity levels exercised by the "1", "2" and "3" test variants, respectively.
private final static double sparsity1 = 0.09;
private final static double sparsity2 = 0.19;
private final static double sparsity3 = 0.29;
// Kind of first-index probe under test; referenced below as IndexType.GT,
// IndexType.GTE and IndexType.LTE.
// NOTE(review): the enum constants and closing brace are not visible here.
public enum IndexType {
// Test fixture setup.
// NOTE(review): the method body (and any @Override annotation) is not visible
// in this copy of the file.
public void setUp() {
public void testSparseBlockMCSR1GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity1, IndexType.GT);
public void testSparseBlockMCSR2GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity2, IndexType.GT);
public void testSparseBlockMCSR3GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity3, IndexType.GT);
public void testSparseBlockMCSR1GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity1, IndexType.GTE);
public void testSparseBlockMCSR2GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity2, IndexType.GTE);
public void testSparseBlockMCSR3GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity3, IndexType.GTE);
public void testSparseBlockMCSR1LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity1, IndexType.LTE);
public void testSparseBlockMCSR2LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity2, IndexType.LTE);
public void testSparseBlockMCSR3LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.MCSR, sparsity3, IndexType.LTE);
public void testSparseBlockCSR1GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity1, IndexType.GT);
public void testSparseBlockCSR2GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity2, IndexType.GT);
public void testSparseBlockCSR3GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity3, IndexType.GT);
public void testSparseBlockCSR1GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity1, IndexType.GTE);
public void testSparseBlockCSR2GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity2, IndexType.GTE);
public void testSparseBlockCSR3GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity3, IndexType.GTE);
public void testSparseBlockCSR1LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity1, IndexType.LTE);
public void testSparseBlockCSR2LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity2, IndexType.LTE);
public void testSparseBlockCSR3LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.CSR, sparsity3, IndexType.LTE);
public void testSparseBlockCOO1GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity1, IndexType.GT);
public void testSparseBlockCOO2GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity2, IndexType.GT);
public void testSparseBlockCOO3GT() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity3, IndexType.GT);
public void testSparseBlockCOO1GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity1, IndexType.GTE);
public void testSparseBlockCOO2GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity2, IndexType.GTE);
public void testSparseBlockCOO3GTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity3, IndexType.GTE);
public void testSparseBlockCOO1LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity1, IndexType.LTE);
public void testSparseBlockCOO2LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity2, IndexType.LTE);
public void testSparseBlockCOO3LTE() {
runSparseBlockGetFirstIndexTest(SparseBlock.Type.COO, sparsity3, IndexType.LTE);
private void runSparseBlockGetFirstIndexTest( SparseBlock.Type btype, double sparsity, IndexType itype)
//data generation
double[][] A = getRandomMatrix(rows, cols, -10, 10, sparsity, 3456);
//init sparse block
SparseBlock sblock = null;
MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
SparseBlock srtmp = mbtmp.getSparseBlock();
switch( btype ) {
case MCSR: sblock = new SparseBlockMCSR(srtmp); break;
case CSR: sblock = new SparseBlockCSR(srtmp); break;
case COO: sblock = new SparseBlockCOO(srtmp); break;
//check for correct number of non-zeros
int[] rnnz = new int[rows]; int nnz = 0;
for( int i=0; i<rows; i++ ) {
for( int j=0; j<cols; j++ )
rnnz[i] += (A[i][j]!=0) ? 1 : 0;
nnz += rnnz[i];
if( nnz != sblock.size() )"Wrong number of non-zeros: "+sblock.size()+", expected: "+nnz);
//check correct isEmpty return
for( int i=0; i<rows; i++ )
if( sblock.isEmpty(i) != (rnnz[i]==0) )"Wrong isEmpty(row) result for row nnz: "+rnnz[i]);
//check correct index values
for( int i=0; i<rows; i++ ) {
int ix = getFirstIx(A, i, i, itype);
int sixpos = -1;
switch( itype ) {
case GT: sixpos = sblock.posFIndexGT(i, i); break;
case GTE: sixpos = sblock.posFIndexGTE(i, i); break;
case LTE: sixpos = sblock.posFIndexLTE(i, i); break;
int six = (sixpos>=0) ?
sblock.indexes(i)[sblock.pos(i)+sixpos] : -1;
if( six != ix ) {"Wrong index returned by index probe ("+
itype.toString()+","+i+"): "+six+", expected: "+ix);
catch(Exception ex) {
throw new RuntimeException(ex);
private static int getFirstIx( double[][] A, int rix, int cix, IndexType type ) {
if( type==IndexType.GT ) {
for( int j=cix+1; j<cols; j++ )
if( A[rix][j] != 0 )
return j;
return -1;
else if( type==IndexType.GTE ) {
for( int j=cix; j<cols; j++ )
if( A[rix][j] != 0 )
return j;
return -1;
else if( type==IndexType.LTE ) {
for( int j=cix; j>=0; j-- )
if( A[rix][j] != 0 )
return j;
return -1;
return -1;