blob: 2b554a261f56b01883270b785c60aa6803312767 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.meta;
import org.apache.sysds.runtime.util.UtilFunctions;
import java.util.Arrays;
/**
 * Data characteristics of a tensor: its dimensions, a single block size
 * (stored in the inherited {@code _blocksize} field), and the number of
 * non-zero values, where a negative count means "unknown".
 */
public class TensorCharacteristics extends DataCharacteristics
{
	private static final long serialVersionUID = 8300479822915546000L;

	// TODO move block size to `ConfigurationManager`
	/**
	 * Default block sizes. The {@link #TensorCharacteristics(long[], long)}
	 * constructor reads the entry at index {@code dims.length - 2}.
	 */
	public static final int[] DEFAULT_BLOCK_SIZE = {1024, 128, 32, 16, 8, 8};

	/** Dimension sizes; remains {@code null} until one of the setters is called. */
	private long[] _dims;

	/** Number of non-zeros; negative values mean the count is unknown. */
	private long _nnz = -1;

	/** Creates an instance without dimensions and with an unknown number of non-zeros. */
	public TensorCharacteristics() {}

	/**
	 * Creates characteristics using the default block size for the given
	 * number of dimensions.
	 *
	 * @param dims dimension sizes; {@code dims.length - 2} must be a valid index
	 *             into {@link #DEFAULT_BLOCK_SIZE}
	 * @param nnz  number of non-zeros (negative if unknown)
	 */
	public TensorCharacteristics(long[] dims, long nnz) {
		set(dims, DEFAULT_BLOCK_SIZE[dims.length - 2], nnz);
	}

	/**
	 * Creates characteristics with an unknown number of non-zeros.
	 *
	 * @param dims      dimension sizes
	 * @param blocksize block size
	 */
	public TensorCharacteristics(long[] dims, int blocksize) {
		set(dims, blocksize, -1);
	}

	/**
	 * Creates fully specified characteristics.
	 *
	 * @param dims      dimension sizes
	 * @param blocksize block size
	 * @param nnz       number of non-zeros (negative if unknown)
	 */
	public TensorCharacteristics(long[] dims, int blocksize, long nnz) {
		set(dims, blocksize, nnz);
	}

	/**
	 * Creates a copy of the given characteristics.
	 *
	 * @param that characteristics to copy from
	 */
	public TensorCharacteristics(DataCharacteristics that) {
		set(that);
	}

	/**
	 * Sets dimensions and block size and marks the number of non-zeros as unknown.
	 *
	 * @param dims      dimension sizes (stored by reference, not copied)
	 * @param blocksize block size
	 * @return this object
	 */
	@Override
	public DataCharacteristics set(long[] dims, int blocksize) {
		set(dims, blocksize, -1);
		return this;
	}

	/**
	 * Sets dimensions, block size, and number of non-zeros.
	 *
	 * @param dims      dimension sizes (stored by reference, not copied)
	 * @param blocksize block size
	 * @param nnz       number of non-zeros (negative if unknown)
	 * @return this object
	 */
	@Override
	public DataCharacteristics set(long[] dims, int blocksize, long nnz) {
		_dims = dims;
		_blocksize = blocksize;
		// Store the non-zero count as well; dropping it would leave every
		// constructor and copy that supplies nnz at the "unknown" default.
		_nnz = nnz;
		return this;
	}

	/**
	 * Copies dimensions, block size, and number of non-zeros from another
	 * characteristics object. The dimension array is cloned.
	 *
	 * @param that characteristics to copy from
	 * @return this object
	 */
	@Override
	public DataCharacteristics set(DataCharacteristics that) {
		long[] dims = that.getDims().clone();
		set(dims, that.getBlocksize(), that.getNonZeros());
		return this;
	}

	/**
	 * Sets the number of non-zeros.
	 *
	 * @param nnz number of non-zeros (negative if unknown)
	 * @return this object
	 */
	@Override
	public DataCharacteristics setNonZeros(long nnz) {
		_nnz = nnz;
		return this;
	}

	/**
	 * Checks whether all dimensions are known.
	 *
	 * @return {@code false} if no dimensions have been set or any dimension is
	 *         negative, {@code true} otherwise
	 */
	@Override
	public boolean dimsKnown() {
		// An instance from the no-arg constructor has no dimensions yet.
		if (_dims == null)
			return false;
		for (long dim : _dims) {
			if (dim < 0)
				return false;
		}
		return true;
	}

	/**
	 * Checks whether all dimensions and, optionally, the number of non-zeros are known.
	 *
	 * @param includeNnz if {@code true}, the number of non-zeros must be known as well
	 * @return {@code true} if the requested information is known
	 */
	@Override
	public boolean dimsKnown(boolean includeNnz) {
		return dimsKnown() && (!includeNnz || nnzKnown());
	}

	/**
	 * @return {@code true} if the number of non-zeros is non-negative
	 */
	@Override
	public boolean nnzKnown() {
		return _nnz >= 0;
	}

	/**
	 * @return the number of dimensions
	 */
	@Override
	public int getNumDims() {
		return _dims.length;
	}

	/**
	 * @param i zero-based dimension index
	 * @return the size of dimension {@code i}
	 */
	@Override
	public long getDim(int i) {
		return _dims[i];
	}

	/**
	 * @return the internal dimension array (not a copy)
	 */
	@Override
	public long[] getLongDims() {
		return _dims;
	}

	/**
	 * @return a new array with every dimension narrowed to {@code int}; values
	 *         outside the {@code int} range are truncated by the cast
	 */
	@Override
	public int[] getIntDims() {
		return Arrays.stream(getLongDims()).mapToInt(i -> (int) i).toArray();
	}

	/**
	 * Sets the size of a single dimension.
	 *
	 * @param i   zero-based dimension index
	 * @param dim new size of dimension {@code i}
	 * @return this object
	 */
	@Override
	public DataCharacteristics setDim(int i, long dim) {
		_dims[i] = dim;
		return this;
	}

	/**
	 * Replaces all dimensions.
	 *
	 * @param dims dimension sizes (stored by reference, not copied)
	 * @return this object
	 */
	@Override
	public DataCharacteristics setDims(long[] dims) {
		_dims = dims;
		return this;
	}

	/**
	 * @return the result of {@code UtilFunctions.prod} applied to the dimensions
	 */
	@Override
	public long getLength() {
		return UtilFunctions.prod(_dims);
	}

	/**
	 * @return the product of the per-dimension block counts
	 */
	@Override
	public long getNumBlocks() {
		long ret = 1;
		for( int i=0; i<getNumDims(); i++ )
			ret *= getNumBlocks(i);
		return ret;
	}

	/**
	 * @param i zero-based dimension index
	 * @return the size of dimension {@code i} divided by the block size,
	 *         rounded up, but at least 1
	 */
	@Override
	public long getNumBlocks(int i) {
		return Math.max((long) Math.ceil((double)getDim(i) / getBlocksize()), 1);
	}

	/**
	 * @return the number of non-zeros (negative if unknown)
	 */
	@Override
	public long getNonZeros() {
		return _nnz;
	}

	@Override
	public String toString() {
		return "["+Arrays.toString(_dims)+", nnz="+_nnz + ", blocksize= "+_blocksize+"]";
	}

	/**
	 * Compares only the dimensions with those of another object.
	 *
	 * @param anObject object to compare with
	 * @return {@code true} if {@code anObject} is a {@code TensorCharacteristics},
	 *         the dimensions of both objects are known, and they are equal
	 */
	@Override
	public boolean equalDims(Object anObject) {
		if( !(anObject instanceof TensorCharacteristics) )
			return false;
		TensorCharacteristics tc = (TensorCharacteristics) anObject;
		return dimsKnown() && tc.dimsKnown()
			&& Arrays.equals(_dims, tc._dims);
	}

	/**
	 * Two tensor characteristics are equal if their dimensions, block size,
	 * and number of non-zeros are equal.
	 */
	@Override
	public boolean equals (Object anObject) {
		if( !(anObject instanceof TensorCharacteristics) )
			return false;
		TensorCharacteristics tc = (TensorCharacteristics) anObject;
		return Arrays.equals(_dims, tc._dims)
			&& _blocksize == tc._blocksize
			&& _nnz == tc._nnz;
	}

	/** Hash over the same fields that {@link #equals(Object)} compares. */
	@Override
	public int hashCode() {
		return UtilFunctions.intHashCode(UtilFunctions.intHashCode(
			Arrays.hashCode(_dims), _blocksize), Long.hashCode(_nnz));
	}
}