blob: 4869e3e97c778b67f73220f8b4b0ceaf53280f82 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.lops.Ctable;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.LongLongDoubleHashMap.EntryType;
/**
 * CP instruction for contingency-table operations (DML {@code table()} / ctable),
 * including the expand variant ({@code ctableexpand}). Builds the result either
 * directly into a dense {@link MatrixBlock} (when output dims are known and a
 * dense representation is expected) or into a {@link CTableMap} hash aggregate
 * that is converted to a MatrixBlock afterwards.
 */
public class CtableCPInstruction extends ComputationCPInstruction {
	// Known output dimensions as scalar FP64 operands; a value of -1 at runtime
	// indicates the dimension is unknown (see processInstruction).
	private final CPOperand _outDim1;
	private final CPOperand _outDim2;
	// true for the ctableexpand opcode (F = ctable(seq, A [, w]))
	private final boolean _isExpand;
	// if true, zero entries in the input are not counted (scalar-weight case only)
	private final boolean _ignoreZeros;

	private CtableCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
		String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand,
		boolean ignoreZeros, String opcode, String istr) {
		super(CPType.Ctable, null, in1, in2, in3, out, opcode, istr);
		_outDim1 = new CPOperand(outputDim1, ValueType.FP64, DataType.SCALAR, dim1Literal);
		_outDim2 = new CPOperand(outputDim2, ValueType.FP64, DataType.SCALAR, dim2Literal);
		_isExpand = isExpand;
		_ignoreZeros = ignoreZeros;
	}

	/**
	 * Parses a ctable/ctableexpand instruction string into a CtableCPInstruction.
	 *
	 * Expected layout (8 fields): opcode, in1, in2, in3, dim1 spec, dim2 spec,
	 * out, ignoreZeros flag. The dim specs are "value{LITERAL_PREFIX}isLiteral" pairs.
	 *
	 * @param inst the raw instruction string
	 * @return the parsed instruction
	 * @throws DMLRuntimeException if the opcode is neither ctable nor ctableexpand
	 */
	public static CtableCPInstruction parseInstruction(String inst)
	{
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
		InstructionUtils.checkNumFields ( parts, 7 );
		String opcode = parts[0];

		//handle opcode
		if ( !(opcode.equalsIgnoreCase("ctable") || opcode.equalsIgnoreCase("ctableexpand")) ) {
			//fixed: message previously referenced the legacy class name TertiaryCPInstruction
			throw new DMLRuntimeException("Unexpected opcode in CtableCPInstruction: " + inst);
		}
		boolean isExpand = opcode.equalsIgnoreCase("ctableexpand");

		//handle operands
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);

		//handle known dimension information
		String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX);
		String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX);

		CPOperand out = new CPOperand(parts[6]);
		boolean ignoreZeros = Boolean.parseBoolean(parts[7]);

		// ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject
		return new CtableCPInstruction(in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst);
	}

	/** Derives the concrete ctable operation type from the three input data types. */
	private Ctable.OperationTypes findCtableOperation() {
		DataType dt1 = input1.getDataType();
		DataType dt2 = input2.getDataType();
		DataType dt3 = input3.getDataType();
		return Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
	}

	@Override
	public void processInstruction(ExecutionContext ec) {
		MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
		MatrixBlock matBlock2=null, wtBlock=null;
		double cst1, cst2;

		CTableMap resultMap = new CTableMap(EntryType.INT);
		MatrixBlock resultBlock = null;
		Ctable.OperationTypes ctableOp = findCtableOperation();
		ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;

		//-1 in either operand marks an unknown output dimension
		long outputDim1 = ec.getScalarInput(_outDim1).getLongValue();
		long outputDim2 = ec.getScalarInput(_outDim2).getLongValue();
		boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1);

		if ( outputDimsKnown ) {
			int inputRows = matBlock1.getNumRows();
			int inputCols = matBlock1.getNumColumns();
			//FIX: widen to long before multiplying; int*int overflows for large inputs
			//(e.g. >46341^2 cells) and would corrupt the sparsity estimate below
			boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, (long)inputRows*inputCols);
			//only create result block if dense; it is important not to aggregate on sparse result
			//blocks because it would implicitly turn the O(N) algorithm into O(N log N).
			if( !sparse )
				resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
		}
		if( _isExpand ){
			//expand: number of rows is known, number of columns is set by the operation itself
			resultBlock = new MatrixBlock( matBlock1.getNumRows(), Integer.MAX_VALUE, true );
		}

		switch(ctableOp) {
			case CTABLE_TRANSFORM: //(VECTOR)
				// F=ctable(A,B,W)
				matBlock2 = ec.getMatrixInput(input2.getName());
				wtBlock = ec.getMatrixInput(input3.getName());
				matBlock1.ctableOperations(_optr, matBlock2, wtBlock, resultMap, resultBlock);
				break;
			case CTABLE_TRANSFORM_SCALAR_WEIGHT: //(VECTOR/MATRIX)
				// F = ctable(A,B) or F = ctable(A,B,1)
				matBlock2 = ec.getMatrixInput(input2.getName());
				cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
				matBlock1.ctableOperations(_optr, matBlock2, cst1, _ignoreZeros, resultMap, resultBlock);
				break;
			case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR)
				// F = ctable(seq,A) or F = ctable(seq,B,1)
				matBlock2 = ec.getMatrixInput(input2.getName());
				cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
				// only resultBlock.rlen known, resultBlock.clen set in operation
				matBlock1.ctableSeqOperations(matBlock2, cst1, resultBlock);
				break;
			case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
				// F=ctable(A,1) or F = ctable(A,1,1)
				cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
				cst2 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
				matBlock1.ctableOperations(_optr, cst1, cst2, resultMap, resultBlock);
				break;
			case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: //(VECTOR)
				// F=ctable(A,1,W)
				wtBlock = ec.getMatrixInput(input3.getName());
				cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
				matBlock1.ctableOperations(_optr, cst1, wtBlock, resultMap, resultBlock);
				break;
			default:
				throw new DMLRuntimeException("Encountered an invalid ctable operation ("+ctableOp+") while executing instruction: " + this.toString());
		}

		//release pinned matrix inputs (scalars have no pinned buffer)
		if(input1.getDataType() == DataType.MATRIX)
			ec.releaseMatrixInput(input1.getName());
		if(input2.getDataType() == DataType.MATRIX)
			ec.releaseMatrixInput(input2.getName());
		if(input3.getDataType() == DataType.MATRIX)
			ec.releaseMatrixInput(input3.getName());

		if ( resultBlock == null ){
			//we need to respect potentially specified output dimensions here, because we might have
			//decided for hash-aggregation just to prevent inefficiency in case of sparse outputs.
			if( outputDimsKnown )
				resultBlock = DataConverter.convertToMatrixBlock( resultMap, (int)outputDim1, (int)outputDim2 );
			else
				resultBlock = DataConverter.convertToMatrixBlock( resultMap );
		}
		else
			resultBlock.examSparsity();

		// Ensure right dense/sparse output representation for special cases
		// such as ctable expand (guarded by released input memory)
		if( checkGuardedRepresentationChange(matBlock1, matBlock2, resultBlock) ) {
			resultBlock.examSparsity();
		}

		ec.setMatrixOutput(output.getName(), resultBlock);
	}

	/**
	 * Builds the lineage item; the output-dimension operands are included only
	 * when at least one of them is known (i.e., not the literal "-1").
	 */
	@Override
	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
		LineageItem[] linputs = !(_outDim1.getName().equals("-1") && _outDim2.getName().equals("-1")) ?
			LineageItemUtils.getLineage(ec, input1, input2, input3, _outDim1, _outDim2) :
			LineageItemUtils.getLineage(ec, input1, input2, input3);
		return Pair.of(output.getName(), new LineageItem(getOpcode(), linputs));
	}
}