blob: 3d112a57b036e1fa4cf9460a579b2f5424e19cd1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.compress.lib;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple;
import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.ConstEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.EmptyEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
/**
* Library functions to combine column groups inside a compressed matrix.
*/
public final class CLALibCombineGroups {
protected static final Log LOG = LogFactory.getLog(CLALibCombineGroups.class.getName());
private CLALibCombineGroups() {
// private constructor
}
public static List<AColGroup> combine(CompressedMatrixBlock cmb, int k) {
ExecutorService pool = null;
try {
pool = (k > 1) ? CommonThreadPool.get(k) : null;
return combine(cmb, null, pool);
}
catch(Exception e) {
throw new DMLCompressionException("Compression Failed", e);
}
finally {
if(pool != null)
pool.shutdown();
}
}
public static List<AColGroup> combine(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool) {
List<AColGroup> input = cmb.getColGroups();
boolean filterFor = CLALibUtils.shouldFilterFOR(input);
double[] c = filterFor ? new double[cmb.getNumColumns()] : null;
if(filterFor)
input = CLALibUtils.filterFOR(input, c);
List<List<AColGroup>> combinations = new ArrayList<>();
for(CompressedSizeInfoColGroup gi : csi.getInfo()) {
combinations.add(findGroupsInIndex(gi.getColumns(), input));
}
List<AColGroup> ret = new ArrayList<>();
if(filterFor)
for(List<AColGroup> combine : combinations)
ret.add(combineN(combine).addVector(c));
else
for(List<AColGroup> combine : combinations)
ret.add(combineN(combine));
return ret;
}
public static List<AColGroup> findGroupsInIndex(IColIndex idx, List<AColGroup> groups) {
List<AColGroup> ret = new ArrayList<>();
for(AColGroup g : groups) {
if(g.getColIndices().containsAny(idx))
ret.add(g);
}
return ret;
}
public static AColGroup combineN(List<AColGroup> groups) {
AColGroup base = groups.get(0);
// Inefficient combine N but base line
for(int i = 1; i < groups.size(); i++) {
base = combine(base, groups.get(i));
}
return base;
}
/**
* Combine the column groups A and B together.
*
* The number of rows should be equal, and it is not verified so there will be unexpected behavior in such cases.
*
* It is assumed that this method is not called with FOR groups
*
* @param a The first group to combine.
* @param b The second group to combine.
* @return A new column group containing the two.
*/
public static AColGroup combine(AColGroup a, AColGroup b) {
if(a instanceof IFrameOfReferenceGroup || b instanceof IFrameOfReferenceGroup)
throw new DMLCompressionException("Invalid call with frame of reference group to combine");
IColIndex combinedColumns = ColIndexFactory.combine(a, b);
// try to recompress a and b if uncompressed
if(a instanceof ColGroupUncompressed)
a = a.recompress();
if(b instanceof ColGroupUncompressed)
b = b.recompress();
if(a instanceof AColGroupCompressed && b instanceof AColGroupCompressed)
return combineCompressed(combinedColumns, (AColGroupCompressed) a, (AColGroupCompressed) b);
else if(a instanceof ColGroupUncompressed || b instanceof ColGroupUncompressed)
// either side is uncompressed
return combineUC(combinedColumns, a, b);
throw new NotImplementedException(
"Not implemented combine for " + a.getClass().getSimpleName() + " - " + b.getClass().getSimpleName());
}
private static AColGroup combineCompressed(IColIndex combinedColumns, AColGroupCompressed ac,
AColGroupCompressed bc) {
IEncode ae = ac.getEncoding();
IEncode be = bc.getEncoding();
if(ae instanceof SparseEncoding && !(be instanceof SparseEncoding)) {
// the order must be sparse second unless both sparse.
return combineCompressed(combinedColumns, bc, ac);
}
Pair<IEncode, Map<Integer,Integer>> cec = ae.combineWithMap(be);
IEncode ce = cec.getLeft();
Map<Integer,Integer> filter = cec.getRight();
if(ce instanceof DenseEncoding) {
DenseEncoding ced = (DenseEncoding) (ce);
ADictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter);
return ColGroupDDC.create(combinedColumns, cd, ced.getMap(), null);
}
else if(ce instanceof EmptyEncoding) {
return new ColGroupEmpty(combinedColumns);
}
else if(ce instanceof ConstEncoding) {
ADictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter);
return ColGroupConst.create(combinedColumns, cd);
}
else if(ce instanceof SparseEncoding) {
SparseEncoding sed = (SparseEncoding) ce;
ADictionary cd = DictionaryFactory.combineDictionariesSparse(ac, bc);
double[] defaultTuple = constructDefaultTuple((AColGroupCompressed) ac, (AColGroupCompressed) bc);
return ColGroupSDC.create(combinedColumns, sed.getNumRows(), cd, defaultTuple, sed.getOffsets(), sed.getMap(),
null);
}
throw new NotImplementedException(
"Not implemented combine for " + ac.getClass().getSimpleName() + " - " + bc.getClass().getSimpleName());
}
private static AColGroup combineUC(IColIndex combinedColumns, AColGroup a, AColGroup b) {
int nRow = a instanceof ColGroupUncompressed ? //
((ColGroupUncompressed) a).getData().getNumRows() : //
((ColGroupUncompressed) b).getData().getNumRows();
// step 1 decompress both into target uncompressed MatrixBlock;
MatrixBlock target = new MatrixBlock(nRow, combinedColumns.size(), false);
target.allocateBlock();
DenseBlock db = target.getDenseBlock();
IColIndex aTempCols = ColIndexFactory.getColumnMapping(combinedColumns, a.getColIndices());
a.copyAndSet(aTempCols).decompressToDenseBlock(db, 0, nRow, 0, 0);
IColIndex bTempCols = ColIndexFactory.getColumnMapping(combinedColumns, b.getColIndices());
b.copyAndSet(bTempCols).decompressToDenseBlock(db, 0, nRow, 0, 0);
target.recomputeNonZeros();
return ColGroupUncompressed.create(combinedColumns, target, false);
}
public static double[] constructDefaultTuple(AColGroupCompressed ac, AColGroupCompressed bc) {
double[] ret = new double[ac.getNumCols() + bc.getNumCols()];
if(ac instanceof IContainDefaultTuple) {
double[] defa = ((IContainDefaultTuple) ac).getDefaultTuple();
System.arraycopy(defa, 0, ret, 0, defa.length);
}
if(bc instanceof IContainDefaultTuple) {
double[] defb = ((IContainDefaultTuple) bc).getDefaultTuple();
System.arraycopy(defb, 0, ret, ac.getNumCols(), defb.length);
}
return ret;
}
}