// blob: 449cc2d2974ca12937f5eaa75e56723e8bb8cb16 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysds.runtime.meta;
import org.apache.sysds.runtime.util.UtilFunctions;
import java.util.Arrays;
public class TensorCharacteristics extends DataCharacteristics
private static final long serialVersionUID = 8300479822915546000L;
// TODO move block size to `ConfigurationManager`
public static final int[] DEFAULT_BLOCK_SIZE = {1024, 128, 32, 16, 8, 8};
private long[] _dims;
private long _nnz = -1;
public TensorCharacteristics() {}
public TensorCharacteristics(long[] dims, long nnz) {
set(dims, DEFAULT_BLOCK_SIZE[dims.length - 2], nnz);
public TensorCharacteristics(long[] dims, int blocksize) {
set(dims, blocksize, -1);
public TensorCharacteristics(long[] dims, int blocksize, long nnz) {
set(dims, blocksize, nnz);
public TensorCharacteristics(DataCharacteristics that) {
public DataCharacteristics set(long[] dims, int blocksize) {
set(dims, blocksize, -1);
return this;
public DataCharacteristics set(long[] dims, int blocksize, long nnz) {
_dims = dims;
_blocksize = blocksize;
return this;
public DataCharacteristics set(DataCharacteristics that) {
long[] dims = that.getDims().clone();
set(dims, that.getBlocksize(), that.getNonZeros());
return this;
public DataCharacteristics setNonZeros(long nnz) {
_nnz = nnz;
return this;
public boolean dimsKnown() {
for (long dim : _dims) {
if (dim < 0)
return false;
return true;
public boolean dimsKnown(boolean includeNnz) {
return dimsKnown() && (!includeNnz || nnzKnown());
public boolean nnzKnown() {
return _nnz >= 0;
public int getNumDims() {
return _dims.length;
public long getDim(int i) {
return _dims[i];
public long[] getLongDims() {
return _dims;
public int[] getIntDims() {
return -> (int) i).toArray();
public DataCharacteristics setDim(int i, long dim) {
_dims[i] = dim;
return this;
public DataCharacteristics setDims(long[] dims) {
_dims = dims;
return this;
public long getLength() {
public long getNumBlocks() {
long ret = 1;
for( int i=0; i<getNumDims(); i++ )
ret *= getNumBlocks(i);
return ret;
public long getNumBlocks(int i) {
return Math.max((long) Math.ceil((double)getDim(i) / getBlocksize()), 1);
public long getNonZeros() {
return _nnz;
public String toString() {
return "["+Arrays.toString(_dims)+", nnz="+_nnz + ", blocksize= "+_blocksize+"]";
public boolean equals (Object anObject) {
if( !(anObject instanceof TensorCharacteristics) )
return false;
TensorCharacteristics tc = (TensorCharacteristics) anObject;
return Arrays.equals(_dims, tc._dims)
&& _blocksize == tc._blocksize
&& _nnz == tc._nnz;
public int hashCode() {
return UtilFunctions.intHashCode(UtilFunctions.intHashCode(
Arrays.hashCode(_dims), _blocksize), Long.hashCode(_nnz));