blob: 904f46dfbeb8cf6701f687f48446083d96770caa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions;
import java.util.StringTokenizer;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropyR;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSigmoidR;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMMR;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.And;
import org.apache.sysds.runtime.functionobjects.BitwAnd;
import org.apache.sysds.runtime.functionobjects.BitwOr;
import org.apache.sysds.runtime.functionobjects.BitwShiftL;
import org.apache.sysds.runtime.functionobjects.BitwShiftR;
import org.apache.sysds.runtime.functionobjects.BitwXor;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.CM;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Equals;
import org.apache.sysds.runtime.functionobjects.GreaterThan;
import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysds.runtime.functionobjects.IfElse;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.functionobjects.IntegerDivide;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.LessThan;
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Minus1Multiply;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.MinusNz;
import org.apache.sysds.runtime.functionobjects.Modulus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Multiply2;
import org.apache.sysds.runtime.functionobjects.Not;
import org.apache.sysds.runtime.functionobjects.NotEquals;
import org.apache.sysds.runtime.functionobjects.Or;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
import org.apache.sysds.runtime.functionobjects.Power;
import org.apache.sysds.runtime.functionobjects.Power2;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceDiag;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.functionobjects.Xor;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
public class InstructionUtils
{
//thread-local string builders for instruction concatenation (avoid allocation)
private static ThreadLocal<StringBuilder> _strBuilders = new ThreadLocal<StringBuilder>() {
@Override
protected StringBuilder initialValue() {
return new StringBuilder(64);
}
};
public static int checkNumFields( String str, int expected ) {
//note: split required for empty tokens
int numParts = str.split(Instruction.OPERAND_DELIM).length;
int numFields = numParts - 2; // -2 accounts for execType and opcode
if ( numFields != expected )
throw new DMLRuntimeException("checkNumFields() for (" + str + ") -- expected number (" + expected + ") != is not equal to actual number (" + numFields + ").");
return numFields;
}
public static int checkNumFields( String[] parts, int expected ) {
int numParts = parts.length;
int numFields = numParts - 1; //account for opcode
if ( numFields != expected )
throw new DMLRuntimeException("checkNumFields() -- expected number (" + expected + ") != is not equal to actual number (" + numFields + ").");
return numFields;
}
public static int checkNumFields( String[] parts, int expected1, int expected2 ) {
int numParts = parts.length;
int numFields = numParts - 1; //account for opcode
if ( numFields != expected1 && numFields != expected2 )
throw new DMLRuntimeException("checkNumFields() -- expected number (" + expected1 + " or "+ expected2 +") != is not equal to actual number (" + numFields + ").");
return numFields;
}
public static int checkNumFields( String str, int expected1, int expected2 ) {
//note: split required for empty tokens
int numParts = str.split(Instruction.OPERAND_DELIM).length;
int numFields = numParts - 2; // -2 accounts for execType and opcode
if ( numFields != expected1 && numFields != expected2 )
throw new DMLRuntimeException("checkNumFields() for (" + str + ") -- expected number (" + expected1 + " or "+ expected2 +") != is not equal to actual number (" + numFields + ").");
return numFields;
}
/**
* Given an instruction string, strip-off the execution type and return
* opcode and all input/output operands WITHOUT their data/value type.
* i.e., ret.length = parts.length-1 (-1 for execution type)
*
* @param str instruction string
* @return instruction parts as string array
*/
public static String[] getInstructionParts( String str ) {
StringTokenizer st = new StringTokenizer( str, Instruction.OPERAND_DELIM );
String[] ret = new String[st.countTokens()-1];
st.nextToken(); // stripping-off the exectype
ret[0] = st.nextToken(); // opcode
int index = 1;
while( st.hasMoreTokens() ){
String tmp = st.nextToken();
int ix = tmp.indexOf(Instruction.DATATYPE_PREFIX);
ret[index++] = tmp.substring(0,((ix>=0)?ix:tmp.length()));
}
return ret;
}
/**
* Given an instruction string, this function strips-off the
* execution type (CP or MR) and returns the remaining parts,
* which include the opcode as well as the input and output operands.
* Each returned part will have the datatype and valuetype associated
* with the operand.
*
* This function is invoked mainly for parsing CPInstructions.
*
* @param str instruction string
* @return instruction parts as string array
*/
public static String[] getInstructionPartsWithValueType( String str ) {
//note: split required for empty tokens
String[] parts = str.split(Instruction.OPERAND_DELIM, -1);
String[] ret = new String[parts.length-1]; // stripping-off the exectype
ret[0] = parts[1]; // opcode
for( int i=1; i<parts.length; i++ )
ret[i-1] = parts[i];
return ret;
}
public static ExecType getExecType( String str ) {
int ix = str.indexOf(Instruction.OPERAND_DELIM);
return ExecType.valueOf(str.substring(0, ix));
}
public static String getOpCode( String str ) {
int ix1 = str.indexOf(Instruction.OPERAND_DELIM);
int ix2 = str.indexOf(Instruction.OPERAND_DELIM, ix1+1);
return str.substring(ix1+1, ix2);
}
public static SPType getSPType(String str) {
return SPInstructionParser.String2SPInstructionType.get(getOpCode(str));
}
public static CPType getCPType(String str) {
return CPInstructionParser.String2CPInstructionType.get(getOpCode(str));
}
public static SPType getSPTypeByOpcode(String opcode) {
return SPInstructionParser.String2SPInstructionType.get(opcode);
}
public static CPType getCPTypeByOpcode( String opcode ) {
return CPInstructionParser.String2CPInstructionType.get(opcode);
}
public static GPUINSTRUCTION_TYPE getGPUType( String str ) {
return GPUInstructionParser.String2GPUInstructionType.get(getOpCode(str));
}
public static FEDType getFEDType(String str) {
return FEDInstructionParser.String2FEDInstructionType.get(getOpCode(str));
}
public static boolean isBuiltinFunction( String opcode ) {
Builtin.BuiltinCode bfc = Builtin.String2BuiltinCode.get(opcode);
return (bfc != null);
}
public static boolean isUnaryMetadata(String opcode) {
return opcode != null
&& (opcode.equals("nrow") || opcode.equals("ncol"));
}
public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode) {
return parseBasicAggregateUnaryOperator(opcode, 1);
}
public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode, int numThreads)
{
AggregateUnaryOperator aggun = null;
if ( opcode.equalsIgnoreCase("uak+") ) {
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uark+") ) { // RowSums
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uack+") ) { // ColSums
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTROW);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uasqk+") ) {
AggregateOperator agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uarsqk+") ) {
// RowSums
AggregateOperator agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uacsqk+") ) {
// ColSums
AggregateOperator agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), CorrectionLocationType.LASTROW);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uamean") ) {
// Mean
AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), CorrectionLocationType.LASTTWOCOLUMNS);
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uarmean") ) {
// RowMeans
AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), CorrectionLocationType.LASTTWOCOLUMNS);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uacmean") ) {
// ColMeans
AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), CorrectionLocationType.LASTTWOROWS);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uavar") ) {
// Variance
CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
CorrectionLocationType cloc = CorrectionLocationType.LASTFOURCOLUMNS;
AggregateOperator agg = new AggregateOperator(0, varFn, cloc);
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uarvar") ) {
// RowVariances
CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
CorrectionLocationType cloc = CorrectionLocationType.LASTFOURCOLUMNS;
AggregateOperator agg = new AggregateOperator(0, varFn, cloc);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uacvar") ) {
// ColVariances
CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
CorrectionLocationType cloc = CorrectionLocationType.LASTFOURROWS;
AggregateOperator agg = new AggregateOperator(0, varFn, cloc);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("ua+") ) {
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uar+") ) {
// RowSums
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uac+") ) {
// ColSums
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("ua*") ) {
AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uar*") ) {
AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uac*") ) {
AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uamax") ) {
AggregateOperator agg = new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max"));
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uamin") ) {
AggregateOperator agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uatrace") ) {
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uaktrace") ) {
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uarmax") ) {
AggregateOperator agg = new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max"));
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if (opcode.equalsIgnoreCase("uarimax") ) {
AggregateOperator agg = new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("maxindex"), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uarmin") ) {
AggregateOperator agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if (opcode.equalsIgnoreCase("uarimin") ) {
AggregateOperator agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("minindex"), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uacmax") ) {
AggregateOperator agg = new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max"));
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
else if ( opcode.equalsIgnoreCase("uacmin") ) {
AggregateOperator agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}
return aggun;
}
public static AggregateTernaryOperator parseAggregateTernaryOperator(String opcode) {
return parseAggregateTernaryOperator(opcode, 1);
}
public static AggregateTernaryOperator parseAggregateTernaryOperator(String opcode, int numThreads) {
CorrectionLocationType corr = opcode.equalsIgnoreCase("tak+*") ?
CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.LASTROW;
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), corr);
IndexFunction ixfun = opcode.equalsIgnoreCase("tak+*") ?
ReduceAll.getReduceAllFnObject() : ReduceRow.getReduceRowFnObject();
return new AggregateTernaryOperator(Multiply.getMultiplyFnObject(), agg, ixfun, numThreads);
}
public static AggregateOperator parseAggregateOperator(String opcode, String corrLoc)
{
AggregateOperator agg = null;
if ( opcode.equalsIgnoreCase("ak+") || opcode.equalsIgnoreCase("aktrace") ) {
CorrectionLocationType lcorrLoc = (corrLoc==null) ? CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.valueOf(corrLoc);
agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), lcorrLoc);
}
else if ( opcode.equalsIgnoreCase("asqk+") ) {
CorrectionLocationType lcorrLoc = (corrLoc==null) ? CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.valueOf(corrLoc);
agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), lcorrLoc);
}
else if ( opcode.equalsIgnoreCase("a+") ) {
agg = new AggregateOperator(0, Plus.getPlusFnObject());
}
else if ( opcode.equalsIgnoreCase("a*") ) {
agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
}
else if (opcode.equalsIgnoreCase("arimax")){
agg = new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("maxindex"), CorrectionLocationType.LASTCOLUMN);
}
else if ( opcode.equalsIgnoreCase("amax") ) {
agg = new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max"));
}
else if ( opcode.equalsIgnoreCase("amin") ) {
agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
}
else if (opcode.equalsIgnoreCase("arimin")){
agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("minindex"), CorrectionLocationType.LASTCOLUMN);
}
else if ( opcode.equalsIgnoreCase("amean") ) {
CorrectionLocationType lcorrLoc = (corrLoc==null) ? CorrectionLocationType.LASTTWOCOLUMNS : CorrectionLocationType.valueOf(corrLoc);
agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), lcorrLoc);
}
else if ( opcode.equalsIgnoreCase("avar") ) {
CorrectionLocationType lcorrLoc = (corrLoc==null) ?
CorrectionLocationType.LASTFOURCOLUMNS :
CorrectionLocationType.valueOf(corrLoc);
CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
agg = new AggregateOperator(0, varFn, lcorrLoc);
}
return agg;
}
public static AggregateUnaryOperator parseBasicCumulativeAggregateUnaryOperator(UnaryOperator uop) {
Builtin f = (Builtin)uop.fn;
if( f.getBuiltinCode()==BuiltinCode.CUMSUM )
return parseBasicAggregateUnaryOperator("uack+") ;
else if( f.getBuiltinCode()==BuiltinCode.CUMPROD )
return parseBasicAggregateUnaryOperator("uac*") ;
else if( f.getBuiltinCode()==BuiltinCode.CUMMIN )
return parseBasicAggregateUnaryOperator("uacmin") ;
else if( f.getBuiltinCode()==BuiltinCode.CUMMAX )
return parseBasicAggregateUnaryOperator("uacmax" ) ;
else if( f.getBuiltinCode()==BuiltinCode.CUMSUMPROD )
return parseBasicAggregateUnaryOperator("uack+*" ) ;
throw new RuntimeException("Unsupported cumulative aggregate unary operator: "+f.getBuiltinCode());
}
public static AggregateUnaryOperator parseCumulativeAggregateUnaryOperator(String opcode) {
AggregateOperator agg = null;
if( "ucumack+".equals(opcode) )
agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTROW);
else if ( "ucumac*".equals(opcode) )
agg = new AggregateOperator(1, Multiply.getMultiplyFnObject(), CorrectionLocationType.NONE);
else if ( "ucumac+*".equals(opcode) )
agg = new AggregateOperator(0, PlusMultiply.getFnObject(), CorrectionLocationType.NONE);
else if ( "ucumacmin".equals(opcode) )
agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("min"), CorrectionLocationType.NONE);
else if ( "ucumacmax".equals(opcode) )
agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("max"), CorrectionLocationType.NONE);
return new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
}
public static UnaryOperator parseUnaryOperator(String opcode) {
return opcode.equals("!") ?
new UnaryOperator(Not.getNotFnObject()) :
new UnaryOperator(Builtin.getBuiltinFnObject(opcode));
}
public static Operator parseBinaryOrBuiltinOperator(String opcode, CPOperand in1, CPOperand in2) {
if( LibCommonsMath.isSupportedMatrixMatrixOperation(opcode) )
return null;
boolean matrixScalar = (in1.getDataType() != in2.getDataType());
return Builtin.isBuiltinFnObject(opcode) ?
(matrixScalar ? new RightScalarOperator( Builtin.getBuiltinFnObject(opcode), 0) :
new BinaryOperator( Builtin.getBuiltinFnObject(opcode))) :
(matrixScalar ? parseScalarBinaryOperator(opcode, in1.getDataType().isScalar()) :
parseBinaryOperator(opcode));
}
public static Operator parseExtendedBinaryOrBuiltinOperator(String opcode, CPOperand in1, CPOperand in2) {
boolean matrixScalar = (in1.getDataType() != in2.getDataType() && (in1.getDataType() != Types.DataType.FRAME && in2.getDataType() != Types.DataType.FRAME));
return Builtin.isBuiltinFnObject(opcode) ?
(matrixScalar ? new RightScalarOperator( Builtin.getBuiltinFnObject(opcode), 0) :
new BinaryOperator( Builtin.getBuiltinFnObject(opcode))) :
(matrixScalar ? parseScalarBinaryOperator(opcode, in1.getDataType().isScalar()) :
parseExtendedBinaryOperator(opcode));
}
public static BinaryOperator parseBinaryOperator(String opcode)
{
if(opcode.equalsIgnoreCase("=="))
return new BinaryOperator(Equals.getEqualsFnObject());
else if(opcode.equalsIgnoreCase("!="))
return new BinaryOperator(NotEquals.getNotEqualsFnObject());
else if(opcode.equalsIgnoreCase("<"))
return new BinaryOperator(LessThan.getLessThanFnObject());
else if(opcode.equalsIgnoreCase(">"))
return new BinaryOperator(GreaterThan.getGreaterThanFnObject());
else if(opcode.equalsIgnoreCase("<="))
return new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject());
else if(opcode.equalsIgnoreCase(">="))
return new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject());
else if(opcode.equalsIgnoreCase("&&"))
return new BinaryOperator(And.getAndFnObject());
else if(opcode.equalsIgnoreCase("||"))
return new BinaryOperator(Or.getOrFnObject());
else if(opcode.equalsIgnoreCase("xor"))
return new BinaryOperator(Xor.getXorFnObject());
else if(opcode.equalsIgnoreCase("bitwAnd"))
return new BinaryOperator(BitwAnd.getBitwAndFnObject());
else if(opcode.equalsIgnoreCase("bitwOr"))
return new BinaryOperator(BitwOr.getBitwOrFnObject());
else if(opcode.equalsIgnoreCase("bitwXor"))
return new BinaryOperator(BitwXor.getBitwXorFnObject());
else if(opcode.equalsIgnoreCase("bitwShiftL"))
return new BinaryOperator(BitwShiftL.getBitwShiftLFnObject());
else if(opcode.equalsIgnoreCase("bitwShiftR"))
return new BinaryOperator(BitwShiftR.getBitwShiftRFnObject());
else if(opcode.equalsIgnoreCase("+"))
return new BinaryOperator(Plus.getPlusFnObject());
else if(opcode.equalsIgnoreCase("-"))
return new BinaryOperator(Minus.getMinusFnObject());
else if(opcode.equalsIgnoreCase("*"))
return new BinaryOperator(Multiply.getMultiplyFnObject());
else if(opcode.equalsIgnoreCase("1-*"))
return new BinaryOperator(Minus1Multiply.getMinus1MultiplyFnObject());
else if ( opcode.equalsIgnoreCase("*2") )
return new BinaryOperator(Multiply2.getMultiply2FnObject());
else if(opcode.equalsIgnoreCase("/"))
return new BinaryOperator(Divide.getDivideFnObject());
else if(opcode.equalsIgnoreCase("%%"))
return new BinaryOperator(Modulus.getFnObject());
else if(opcode.equalsIgnoreCase("%/%"))
return new BinaryOperator(IntegerDivide.getFnObject());
else if(opcode.equalsIgnoreCase("^"))
return new BinaryOperator(Power.getPowerFnObject());
else if ( opcode.equalsIgnoreCase("^2") )
return new BinaryOperator(Power2.getPower2FnObject());
else if ( opcode.equalsIgnoreCase("max") )
return new BinaryOperator(Builtin.getBuiltinFnObject("max"));
else if ( opcode.equalsIgnoreCase("min") )
return new BinaryOperator(Builtin.getBuiltinFnObject("min"));
else if( opcode.equalsIgnoreCase("dropInvalidType"))
return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidType"));
else if( opcode.equalsIgnoreCase("dropInvalidLength"))
return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength"));
throw new RuntimeException("Unknown binary opcode " + opcode);
}
public static TernaryOperator parseTernaryOperator(String opcode) {
return new TernaryOperator(opcode.equals("+*") ? PlusMultiply.getFnObject() :
opcode.equals("-*") ? MinusMultiply.getFnObject() : IfElse.getFnObject());
}
/**
* scalar-matrix operator
*
* @param opcode the opcode
* @param arg1IsScalar ?
* @return scalar operator
*/
public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean arg1IsScalar)
{
//for all runtimes that set constant dynamically (cp/spark)
double default_constant = 0;
return parseScalarBinaryOperator(opcode, arg1IsScalar, default_constant);
}
/**
* scalar-matrix operator
*
* @param opcode the opcode
* @param arg1IsScalar ?
* @param constant ?
* @return scalar operator
*/
public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean arg1IsScalar, double constant)
{
//commutative operators
if ( opcode.equalsIgnoreCase("+") ){
return new RightScalarOperator(Plus.getPlusFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("*") ) {
return new RightScalarOperator(Multiply.getMultiplyFnObject(), constant);
}
//non-commutative operators
else if ( opcode.equalsIgnoreCase("-") ) {
if(arg1IsScalar)
return new LeftScalarOperator(Minus.getMinusFnObject(), constant);
else return new RightScalarOperator(Minus.getMinusFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("-nz") ) {
//no support for left scalar yet
return new RightScalarOperator(MinusNz.getMinusNzFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("/") ) {
if(arg1IsScalar)
return new LeftScalarOperator(Divide.getDivideFnObject(), constant);
else return new RightScalarOperator(Divide.getDivideFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("%%") ) {
if(arg1IsScalar)
return new LeftScalarOperator(Modulus.getFnObject(), constant);
else return new RightScalarOperator(Modulus.getFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("%/%") ) {
if(arg1IsScalar)
return new LeftScalarOperator(IntegerDivide.getFnObject(), constant);
else return new RightScalarOperator(IntegerDivide.getFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("^") ){
if(arg1IsScalar)
return new LeftScalarOperator(Power.getPowerFnObject(), constant);
else return new RightScalarOperator(Power.getPowerFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("max") ) {
return new RightScalarOperator(Builtin.getBuiltinFnObject("max"), constant);
}
else if ( opcode.equalsIgnoreCase("min") ) {
return new RightScalarOperator(Builtin.getBuiltinFnObject("min"), constant);
}
else if ( opcode.equalsIgnoreCase("log") || opcode.equalsIgnoreCase("log_nz") ){
if( arg1IsScalar )
return new LeftScalarOperator(Builtin.getBuiltinFnObject(opcode), constant);
return new RightScalarOperator(Builtin.getBuiltinFnObject(opcode), constant);
}
else if ( opcode.equalsIgnoreCase(">") ) {
if(arg1IsScalar)
return new LeftScalarOperator(GreaterThan.getGreaterThanFnObject(), constant);
return new RightScalarOperator(GreaterThan.getGreaterThanFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase(">=") ) {
if(arg1IsScalar)
return new LeftScalarOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), constant);
return new RightScalarOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("<") ) {
if(arg1IsScalar)
return new LeftScalarOperator(LessThan.getLessThanFnObject(), constant);
return new RightScalarOperator(LessThan.getLessThanFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("<=") ) {
if(arg1IsScalar)
return new LeftScalarOperator(LessThanEquals.getLessThanEqualsFnObject(), constant);
return new RightScalarOperator(LessThanEquals.getLessThanEqualsFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("==") ) {
if(arg1IsScalar)
return new LeftScalarOperator(Equals.getEqualsFnObject(), constant);
return new RightScalarOperator(Equals.getEqualsFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("!=") ) {
if(arg1IsScalar)
return new LeftScalarOperator(NotEquals.getNotEqualsFnObject(), constant);
return new RightScalarOperator(NotEquals.getNotEqualsFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("&&") ) {
return arg1IsScalar ?
new LeftScalarOperator(And.getAndFnObject(), constant) :
new RightScalarOperator(And.getAndFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("||") ) {
return arg1IsScalar ?
new LeftScalarOperator(Or.getOrFnObject(), constant) :
new RightScalarOperator(Or.getOrFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("xor") ) {
return arg1IsScalar ?
new LeftScalarOperator(Xor.getXorFnObject(), constant) :
new RightScalarOperator(Xor.getXorFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("bitwAnd") ) {
return arg1IsScalar ?
new LeftScalarOperator(BitwAnd.getBitwAndFnObject(), constant) :
new RightScalarOperator(BitwAnd.getBitwAndFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("bitwOr") ) {
return arg1IsScalar ?
new LeftScalarOperator(BitwOr.getBitwOrFnObject(), constant) :
new RightScalarOperator(BitwOr.getBitwOrFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("bitwXor") ) {
return arg1IsScalar ?
new LeftScalarOperator(BitwXor.getBitwXorFnObject(), constant) :
new RightScalarOperator(BitwXor.getBitwXorFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("bitwShiftL") ) {
return arg1IsScalar ?
new LeftScalarOperator(BitwShiftL.getBitwShiftLFnObject(), constant) :
new RightScalarOperator(BitwShiftL.getBitwShiftLFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("bitwShiftR") ) {
return arg1IsScalar ?
new LeftScalarOperator(BitwShiftR.getBitwShiftRFnObject(), constant) :
new RightScalarOperator(BitwShiftR.getBitwShiftRFnObject(), constant);
}
//operations that only exist for performance purposes (all unary or commutative operators)
else if ( opcode.equalsIgnoreCase("*2") ) {
return new RightScalarOperator(Multiply2.getMultiply2FnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("^2") ){
return new RightScalarOperator(Power2.getPower2FnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("1-*") ) {
return new RightScalarOperator(Minus1Multiply.getMinus1MultiplyFnObject(), constant);
}
//operations that only exist in mr
else if ( opcode.equalsIgnoreCase("s-r") ) {
return new LeftScalarOperator(Minus.getMinusFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("so") ) {
return new LeftScalarOperator(Divide.getDivideFnObject(), constant);
}
throw new RuntimeException("Unknown binary opcode " + opcode);
}
public static BinaryOperator parseExtendedBinaryOperator(String opcode) {
if(opcode.equalsIgnoreCase("==") || opcode.equalsIgnoreCase("map=="))
return new BinaryOperator(Equals.getEqualsFnObject());
else if(opcode.equalsIgnoreCase("!=") || opcode.equalsIgnoreCase("map!="))
return new BinaryOperator(NotEquals.getNotEqualsFnObject());
else if(opcode.equalsIgnoreCase("<") || opcode.equalsIgnoreCase("map<"))
return new BinaryOperator(LessThan.getLessThanFnObject());
else if(opcode.equalsIgnoreCase(">") || opcode.equalsIgnoreCase("map>"))
return new BinaryOperator(GreaterThan.getGreaterThanFnObject());
else if(opcode.equalsIgnoreCase("<=") || opcode.equalsIgnoreCase("map<="))
return new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject());
else if(opcode.equalsIgnoreCase(">=") || opcode.equalsIgnoreCase("map>="))
return new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject());
else if(opcode.equalsIgnoreCase("&&") || opcode.equalsIgnoreCase("map&&"))
return new BinaryOperator(And.getAndFnObject());
else if(opcode.equalsIgnoreCase("||") || opcode.equalsIgnoreCase("map||"))
return new BinaryOperator(Or.getOrFnObject());
else if(opcode.equalsIgnoreCase("xor") || opcode.equalsIgnoreCase("mapxor"))
return new BinaryOperator(Xor.getXorFnObject());
else if(opcode.equalsIgnoreCase("bitwAnd") || opcode.equalsIgnoreCase("mapbitwAnd"))
return new BinaryOperator(BitwAnd.getBitwAndFnObject());
else if(opcode.equalsIgnoreCase("bitwOr") || opcode.equalsIgnoreCase("mapbitwOr"))
return new BinaryOperator(BitwOr.getBitwOrFnObject());
else if(opcode.equalsIgnoreCase("bitwXor") || opcode.equalsIgnoreCase("mapbitwXor"))
return new BinaryOperator(BitwXor.getBitwXorFnObject());
else if(opcode.equalsIgnoreCase("bitwShiftL") || opcode.equalsIgnoreCase("mapbitwShiftL"))
return new BinaryOperator(BitwShiftL.getBitwShiftLFnObject());
else if(opcode.equalsIgnoreCase("bitwShiftR") || opcode.equalsIgnoreCase("mapbitwShiftR"))
return new BinaryOperator(BitwShiftR.getBitwShiftRFnObject());
else if(opcode.equalsIgnoreCase("+") || opcode.equalsIgnoreCase("map+"))
return new BinaryOperator(Plus.getPlusFnObject());
else if(opcode.equalsIgnoreCase("-") || opcode.equalsIgnoreCase("map-"))
return new BinaryOperator(Minus.getMinusFnObject());
else if(opcode.equalsIgnoreCase("*") || opcode.equalsIgnoreCase("map*"))
return new BinaryOperator(Multiply.getMultiplyFnObject());
else if(opcode.equalsIgnoreCase("1-*") || opcode.equalsIgnoreCase("map1-*"))
return new BinaryOperator(Minus1Multiply.getMinus1MultiplyFnObject());
else if ( opcode.equalsIgnoreCase("*2") )
return new BinaryOperator(Multiply2.getMultiply2FnObject());
else if(opcode.equalsIgnoreCase("/") || opcode.equalsIgnoreCase("map/"))
return new BinaryOperator(Divide.getDivideFnObject());
else if(opcode.equalsIgnoreCase("%%") || opcode.equalsIgnoreCase("map%%"))
return new BinaryOperator(Modulus.getFnObject());
else if(opcode.equalsIgnoreCase("%/%") || opcode.equalsIgnoreCase("map%/%"))
return new BinaryOperator(IntegerDivide.getFnObject());
else if(opcode.equalsIgnoreCase("^") || opcode.equalsIgnoreCase("map^"))
return new BinaryOperator(Power.getPowerFnObject());
else if ( opcode.equalsIgnoreCase("^2") )
return new BinaryOperator(Power2.getPower2FnObject());
else if ( opcode.equalsIgnoreCase("max") || opcode.equalsIgnoreCase("mapmax") )
return new BinaryOperator(Builtin.getBuiltinFnObject("max"));
else if ( opcode.equalsIgnoreCase("min") || opcode.equalsIgnoreCase("mapmin") )
return new BinaryOperator(Builtin.getBuiltinFnObject("min"));
else if ( opcode.equalsIgnoreCase("dropInvalidLength") || opcode.equalsIgnoreCase("mapdropInvalidLength") )
return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength"));
throw new DMLRuntimeException("Unknown binary opcode " + opcode);
}
public static String deriveAggregateOperatorOpcode(String opcode) {
switch( opcode ) {
case "uak+":
case"uark+":
case "uack+": return "ak+";
case "ua+":
case "uar+":
case "uac+": return "a+";
case "uatrace":
case "uaktrace": return "aktrace";
case "uasqk+":
case "uarsqk+":
case "uacsqk+": return "asqk+";
case "uamean":
case "uarmean":
case "uacmean": return "amean";
case "uavar":
case "uarvar":
case "uacvar": return "avar";
case "ua*":
case "uar*":
case "uac*": return "a*";
case "uamax":
case "uarmax":
case "uacmax": return "amax";
case "uamin":
case "uarmin":
case "uacmin": return "amin";
case "uarimax": return "arimax";
case "uarimin": return "arimin";
}
return null;
}
public static AggOp getAggOp(String opcode) {
switch( opcode ) {
case "uak+":
case"uark+":
case "uack+":
case "ua+":
case "uar+":
case "uac+":
case "uatrace":
case "uaktrace": return AggOp.SUM;
case "uasqk+":
case "uarsqk+":
case "uacsqk+": return AggOp.SUM_SQ;
case "uamean":
case "uarmean":
case "uacmean": return AggOp.MEAN;
case "uavar":
case "uarvar":
case "uacvar": return AggOp.VAR;
case "ua*":
case "uar*":
case "uac*": return AggOp.PROD;
case "uamax":
case "uarmax":
case "uacmax": return AggOp.MAX;
case "uamin":
case "uarmin":
case "uacmin": return AggOp.MIN;
case "uarimax": return AggOp.MAXINDEX;
case "uarimin": return AggOp.MININDEX;
}
return null;
}
public static Direction getAggDirection(String opcode) {
switch( opcode ) {
case "uak+":
case "ua+":
case "uatrace":
case "uaktrace":
case "uasqk+":
case "uamean":
case "uavar":
case "ua*":
case "uamax":
case "uamin": return Direction.RowCol;
case"uark+":
case "uar+":
case "uarsqk+":
case "uarmean":
case "uar*":
case "uarmax":
case "uarmin":
case "uarimax":
case "uarimin": return Direction.Row;
case "uack+":
case "uac+":
case "uacsqk+":
case "uacmean":
case "uarvar":
case "uacvar":
case "uac*":
case "uacmax":
case "uacmin": return Direction.Col;
}
return null;
}
public static CorrectionLocationType deriveAggregateOperatorCorrectionLocation(String opcode)
{
if ( opcode.equalsIgnoreCase("uak+") || opcode.equalsIgnoreCase("uark+") ||
opcode.equalsIgnoreCase("uasqk+") || opcode.equalsIgnoreCase("uarsqk+") ||
opcode.equalsIgnoreCase("uatrace") || opcode.equalsIgnoreCase("uaktrace") )
return CorrectionLocationType.LASTCOLUMN;
else if ( opcode.equalsIgnoreCase("uack+") || opcode.equalsIgnoreCase("uacsqk+") )
return CorrectionLocationType.LASTROW;
else if ( opcode.equalsIgnoreCase("uamean") || opcode.equalsIgnoreCase("uarmean") )
return CorrectionLocationType.LASTTWOCOLUMNS;
else if ( opcode.equalsIgnoreCase("uacmean") )
return CorrectionLocationType.LASTTWOROWS;
else if ( opcode.equalsIgnoreCase("uavar") || opcode.equalsIgnoreCase("uarvar") )
return CorrectionLocationType.LASTFOURCOLUMNS;
else if ( opcode.equalsIgnoreCase("uacvar") )
return CorrectionLocationType.LASTFOURROWS;
else if (opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin") )
return CorrectionLocationType.LASTCOLUMN;
return CorrectionLocationType.NONE;
}
public static boolean isDistQuaternaryOpcode(String opcode)
{
return WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) //mapwsloss
|| WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) //redwsloss
|| WeightedSigmoid.OPCODE.equalsIgnoreCase(opcode) //mapwsigmoid
|| WeightedSigmoidR.OPCODE.equalsIgnoreCase(opcode) //redwsigmoid
|| WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) //mapwdivmm
|| WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) //redwdivmm
|| WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode) //mapwcemm
|| WeightedCrossEntropyR.OPCODE.equalsIgnoreCase(opcode) //redwcemm
|| WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //mapwumm
|| WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode); //redwumm
}
public static AggregateBinaryOperator getMatMultOperator(int k) {
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
return new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg, k);
}
public static Operator parseGroupedAggOperator(String fn, String other) {
AggregateOperationTypes op = AggregateOperationTypes.INVALID;
if ( fn.equalsIgnoreCase("centralmoment") )
// in case of CM, we also need to pass "order"
op = CMOperator.getAggOpType(fn, other);
else
op = CMOperator.getAggOpType(fn, null);
switch(op) {
case SUM:
return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
case COUNT:
case MEAN:
case VARIANCE:
case CM2:
case CM3:
case CM4:
return new CMOperator(CM.getCMFnObject(op), op);
case INVALID:
default:
throw new DMLRuntimeException("Invalid Aggregate Operation in GroupedAggregateInstruction: " + op);
}
}
public static String replaceOperand(String instStr, int operand, String newValue) {
//split instruction and check for correctness
String[] parts = instStr.split(Lop.OPERAND_DELIMITOR);
if( operand >= parts.length )
throw new DMLRuntimeException("Operand position "
+ operand + " exceeds the length of the instruction.");
//replace and reconstruct string
parts[operand] = newValue;
return concatOperands(parts);
}
public static String replaceOperandName(String instStr) {
String[] parts = instStr.split(Lop.OPERAND_DELIMITOR);
String oldName = parts[parts.length-1];
String[] Nameparts = oldName.split(Instruction.VALUETYPE_PREFIX);
Nameparts[0] = "xxx";
String newName = concatOperandParts(Nameparts);
parts[parts.length-1] = newName;
return concatOperands(parts);
}
public static String concatOperands(String... inputs) {
return concatOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
}
public static String concatOperandParts(String... inputs) {
return concatOperandsWithDelim(Instruction.VALUETYPE_PREFIX, inputs);
}
private static String concatOperandsWithDelim(String delim, String... inputs) {
StringBuilder sb = _strBuilders.get();
sb.setLength(0); //reuse allocated space
for( int i=0; i<inputs.length-1; i++ ) {
sb.append(inputs[i]);
sb.append(delim);
}
sb.append(inputs[inputs.length-1]);
return sb.toString();
}
public static String concatStrings(String... inputs) {
StringBuilder sb = _strBuilders.get();
sb.setLength(0); //reuse allocated space
for( int i=0; i<inputs.length; i++ )
sb.append(inputs[i]);
return sb.toString();
}
}