// blob: 01239f760ab9dc84c3bbf7ab42e0eb60e944de1c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.compress.lib;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLCompressionException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.Dictionary;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell.BinaryAccessType;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
/**
 * Cell-wise binary operations where the left-hand side is a compressed matrix block.
 *
 * The entry point is {@link #bincellOp}. Where possible the operation is applied directly to the column groups of
 * the compressed input; otherwise the input is decompressed and the uncompressed operation is used.
 */
public class LibBinaryCellOp {
	private static final Log LOG = LogFactory.getLog(LibBinaryCellOp.class.getName());

	/**
	 * matrix-matrix binary operations, MM, MV
	 *
	 * Computes {@code m1 op m2} with m1 as the left-hand and m2 as the right-hand operand.
	 *
	 * @param m1  Input matrix in compressed format to perform the operator on using m2
	 * @param m2  Input matrix 2 to use cell-wise on m1
	 * @param ret Result matrix to be returned in the operation.
	 * @param op  Binary operator such as multiply, add, less than, etc.
	 * @return The ret matrix, modified appropriately.
	 * @throws NotImplementedException if m1 is overlapping and the operator is neither plus, minus nor multiply
	 */
	public static MatrixBlock bincellOp(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret,
		BinaryOperator op) {
		if(op.fn instanceof Minus) {
			// Rewrite m1 - m2 as m1 + (m2 * -1), so only Plus has to be supported further down.
			ScalarOperator sop = new RightScalarOperator(Multiply.getMultiplyFnObject(), -1);
			m2 = m2.scalarOperations(sop, new MatrixBlock());
			return LibBinaryCellOp.bincellOp(m1, m2, ret, new BinaryOperator(Plus.getPlusFnObject()));
		}
		if(m1.isOverlapping() && !(op.fn instanceof Multiply)) {
			// Minus never reaches this point (rewritten to Plus above), so Plus is the only stackable operator.
			// NOTE(review): the access type is not checked on this path; binaryMVPlusStack reads m2 as a
			// dense value array indexed by column - confirm callers only pass row vectors here.
			if(op.fn instanceof Plus) {
				return binaryMVPlusStack(m1, m2, ret, op);
			}
			else {
				throw new NotImplementedException(op + " not implemented for CLA");
			}
		}
		else {
			BinaryAccessType atype = LibMatrixBincell.getBinaryAccessType(m1, m2);
			switch(atype) {
				case MATRIX_ROW_VECTOR:
					// Verify if it is okay to include all OuterVectorVector ops here.
					return binaryMVRow(m1, m2, ret, op);
				case OUTER_VECTOR_VECTOR:
					if(m2.getNumRows() == 1 && m2.getNumColumns() == 1) {
						// A 1x1 right-hand side is handled as a scalar operation with the constant on the right.
						return LibScalar.scalarOperations(new RightScalarOperator(op.fn, m2.quickGetValue(0, 0)),
							m1,
							ret,
							m1.isOverlapping());
					}
					// Intentional fall through: any other outer vector-vector shape is handled by decompression.
				default:
					LOG.warn("Inefficient Decompression for " + op + " " + atype);
					MatrixBlock m1d = m1.decompress();
					return m1d.binaryOperations(op, m2, ret);
			}
		}
	}

	/**
	 * Apply {@code m1 op m2} column group by column group, where m2 is read as a row vector.
	 *
	 * @param m1  The compressed left-hand side input
	 * @param m2  The right-hand side; its dense values are indexed by column
	 * @param ret The compressed output block that receives the new column groups
	 * @param op  The binary operator to apply
	 * @return The ret block, holding one new column group per column group of m1
	 * @throws DMLCompressionException if m1 contains an uncompressed column group
	 */
	protected static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock m1, MatrixBlock m2,
		CompressedMatrixBlock ret, BinaryOperator op) {
		// Apply the operation to each of the column groups.
		// Most implementations will only modify metadata.
		List<ColGroup> oldColGroups = m1.getColGroups();
		List<ColGroup> newColGroups = new ArrayList<>(oldColGroups.size());

		// NOTE(review): used without a null check; presumably m2 is expected to be dense here - confirm.
		double[] v = m2.getDenseBlockValues();

		// The operation is sparse safe when a zero cell of m1 stays zero for every value of the vector,
		// i.e. (0 op x) == 0. The zero is the left operand because m1 is the left-hand side.
		boolean sparseSafe = true;
		for(double x : v) {
			if(op.fn.execute(0.0, x) != 0.0) {
				sparseSafe = false;
				break;
			}
		}

		for(ColGroup grp : oldColGroups) {
			if(grp instanceof ColGroupUncompressed) {
				throw new DMLCompressionException("Not supported Binary MV");
			}
			else {
				if(grp.getNumCols() == 1) {
					// Single column groups see exactly one value of the vector, so this is a scalar operation.
					// The vector value is the right-hand operand, matching the scalar path in bincellOp.
					ScalarOperator sop = new RightScalarOperator(op.fn, m2.getValue(0, grp.getColIndices()[0]));
					newColGroups.add(grp.scalarOperation(sop));
				}
				else {
					ColGroup ncg = grp.binaryRowOp(op, v, sparseSafe);
					newColGroups.add(ncg);
				}
			}
		}
		ret.allocateColGroupList(newColGroups);
		// The non-zero count is set to the total number of cells; multiply as long to avoid int overflow.
		ret.setNonZeros((long) m1.getNumColumns() * m1.getNumRows());
		return ret;
	}

	/**
	 * Apply a plus operation to an overlapping compressed matrix by appending one extra constant column group.
	 *
	 * The existing column groups are reused unchanged; the result is marked as overlapping.
	 *
	 * @param m1  The overlapping compressed left-hand side input
	 * @param m2  The right-hand side; its dense values are indexed by column
	 * @param ret The compressed output block that receives the column groups
	 * @param op  The binary operator used to build the values of the constant column group
	 * @return The ret block, holding all column groups of m1 plus the new constant group
	 */
	protected static CompressedMatrixBlock binaryMVPlusStack(CompressedMatrixBlock m1, MatrixBlock m2,
		CompressedMatrixBlock ret, BinaryOperator op) {
		List<ColGroup> oldColGroups = m1.getColGroups();
		List<ColGroup> newColGroups = new ArrayList<>(oldColGroups.size() + 1);
		newColGroups.addAll(oldColGroups);

		// NOTE(review): the column indexes of the first group are used for the new constant group;
		// presumably every group of an overlapping block covers the same columns - confirm.
		int[] colIndexes = oldColGroups.get(0).getColIndices();
		double[] v = m2.getDenseBlockValues();

		// Start from an all-zero tuple and apply the operator with the vector to get the constant values.
		ADictionary newDict = new Dictionary(new double[colIndexes.length]);
		newDict = newDict.applyBinaryRowOp(op.fn, v, true, colIndexes);
		newColGroups.add(new ColGroupConst(colIndexes, m1.getNumRows(), newDict));

		ret.allocateColGroupList(newColGroups);
		ret.setOverlapping(true);
		// The number of non-zeros is set to -1 here rather than computed.
		ret.setNonZeros(-1);
		return ret;
	}
}