blob: d8fcd47cf4b580af4a3631561d40235e14aec74c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.transform.encode;
import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.utils.Statistics;
public class ColumnEncoderDummycode extends ColumnEncoder {
private static final long serialVersionUID = 5832130477659116489L;
public int _domainSize = -1; // length = #of dummycoded columns
public ColumnEncoderDummycode() {
super(-1);
}
public ColumnEncoderDummycode(int colID) {
super(colID);
}
public ColumnEncoderDummycode(int colID, int domainSize) {
super(colID);
_domainSize = domainSize;
}
@Override
protected TransformType getTransformType() {
return TransformType.DUMMYCODE;
}
@Override
public void build(CacheBlock in) {
// do nothing
}
@Override
public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
return null;
}
@Override
protected double getCode(CacheBlock in, int row) {
throw new DMLRuntimeException("DummyCoder does not have a code");
}
protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){
if (!(in instanceof MatrixBlock)){
throw new DMLRuntimeException("ColumnEncoderDummycode called with: " + in.getClass().getSimpleName() +
" and not MatrixBlock");
}
Set<Integer> sparseRowsWZeros = null;
int index = _colID - 1;
for(int r = rowStart; r < getEndIndex(in.getNumRows(), rowStart, blk); r++) {
// Since the recoded values are already offset in the output matrix (same as input at this point)
// the dummycoding only needs to offset them within their column domain. Which means that the
// indexes in the SparseRowVector do not need to be sorted anymore and can be updated directly.
//
// Input: Output:
//
// 1 | 0 | 2 | 0 1 | 0 | 0 | 1
// 2 | 0 | 1 | 0 ===> 0 | 1 | 1 | 0
// 1 | 0 | 2 | 0 1 | 0 | 0 | 1
// 1 | 0 | 1 | 0 1 | 0 | 1 | 0
//
// Example SparseRowVector Internals (1. row):
//
// indexes = [0,2] ===> indexes = [0,3]
// values = [1,2] values = [1,1]
double val = out.getSparseBlock().get(r).values()[index];
if(Double.isNaN(val)){
if(sparseRowsWZeros == null)
sparseRowsWZeros = new HashSet<>();
sparseRowsWZeros.add(r);
out.getSparseBlock().get(r).values()[index] = 0;
continue;
}
int nCol = outputCol + (int) val - 1;
out.getSparseBlock().get(r).indexes()[index] = nCol;
out.getSparseBlock().get(r).values()[index] = 1;
}
if(sparseRowsWZeros != null){
addSparseRowsWZeros(sparseRowsWZeros);
}
}
protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){
if (!(in instanceof MatrixBlock)){
throw new DMLRuntimeException("ColumnEncoderDummycode called with: " + in.getClass().getSimpleName() +
" and not MatrixBlock");
}
for(int i = rowStart; i < getEndIndex(in.getNumRows(), rowStart, blk); i++) {
// Using outputCol here as index since we have a MatrixBlock as input where dummycoding could have been
// applied in a previous encoder
double val = in.getDouble(i, outputCol);
if(Double.isNaN(val)){
// 0 if NaN
out.quickSetValue(i, outputCol, 0);
continue;
}
int nCol = outputCol + (int) val - 1;
if(nCol != outputCol)
out.quickSetValue(i, outputCol, 0);
out.quickSetValue(i, nCol, 1);
}
}
@Override
protected ColumnApplyTask<? extends ColumnEncoder>
getSparseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) {
if (!(in instanceof MatrixBlock)){
throw new DMLRuntimeException("ColumnEncoderDummycode called with: " + in.getClass().getSimpleName() +
" and not MatrixBlock");
}
return new DummycodeSparseApplyTask(this, (MatrixBlock) in, out, outputCol, startRow, blk);
}
@Override
public void mergeAt(ColumnEncoder other) {
if(other instanceof ColumnEncoderDummycode) {
assert other._colID == _colID;
// temporary, will be updated later
_domainSize = 0;
return;
}
super.mergeAt(other);
}
@Override
public void updateIndexRanges(long[] beginDims, long[] endDims, int colOffset) {
// new columns inserted in this (federated) block
beginDims[1] += colOffset;
endDims[1] += _domainSize - 1 + colOffset;
}
public void updateDomainSizes(List<ColumnEncoder> columnEncoders) {
if(_colID == -1)
return;
for(ColumnEncoder columnEncoder : columnEncoders) {
int distinct = -1;
if(columnEncoder instanceof ColumnEncoderRecode) {
ColumnEncoderRecode columnEncoderRecode = (ColumnEncoderRecode) columnEncoder;
distinct = columnEncoderRecode.getNumDistinctValues();
}
else if(columnEncoder instanceof ColumnEncoderBin) {
distinct = ((ColumnEncoderBin) columnEncoder)._numBin;
}
else if(columnEncoder instanceof ColumnEncoderFeatureHash){
distinct = (int) ((ColumnEncoderFeatureHash) columnEncoder).getK();
}
if(distinct != -1) {
_domainSize = distinct;
LOG.debug("DummyCoder for column: " + _colID + " has domain size: " + _domainSize);
}
}
}
@Override
public FrameBlock getMetaData(FrameBlock meta) {
return meta;
}
@Override
public void initMetaData(FrameBlock meta) {
// initialize domain sizes and output num columns
_domainSize = -1;
_domainSize = (int) meta.getColumnMetadata()[_colID - 1].getNumDistinct();
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
super.writeExternal(out);
out.writeInt(_domainSize);
}
@Override
public void readExternal(ObjectInput in) throws IOException {
super.readExternal(in);
_domainSize = in.readInt();
}
@Override
public boolean equals(Object o) {
if(this == o)
return true;
if(o == null || getClass() != o.getClass())
return false;
ColumnEncoderDummycode that = (ColumnEncoderDummycode) o;
return _colID == that._colID && (_domainSize == that._domainSize);
}
@Override
public int hashCode() {
int result = Objects.hash(_colID);
result = 31 * result + Objects.hashCode(_domainSize);
return result;
}
public int getDomainSize() {
return _domainSize;
}
private static class DummycodeSparseApplyTask extends ColumnApplyTask<ColumnEncoderDummycode> {
protected DummycodeSparseApplyTask(ColumnEncoderDummycode encoder, MatrixBlock input,
MatrixBlock out, int outputCol) {
super(encoder, input, out, outputCol);
}
protected DummycodeSparseApplyTask(ColumnEncoderDummycode encoder, MatrixBlock input,
MatrixBlock out, int outputCol, int startRow, int blk) {
super(encoder, input, out, outputCol, startRow, blk);
}
public Object call() throws Exception {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
if(_out.getSparseBlock() == null)
return null;
_encoder.applySparse(_input, _out, _outputCol, _startRow, _blk);
if (DMLScript.STATISTICS)
Statistics.incTransformDummyCodeApplyTime(System.nanoTime()-t0);
return null;
}
@Override
public String toString() {
return getClass().getSimpleName() + "<ColId: " + _encoder._colID + ">";
}
}
}