blob: e89579795efa5d796f6743fb62dd107aafc513fe [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions;
import java.util.HashMap;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.AggregateBinaryGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.AggregateUnaryGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.ArithmeticBinaryGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.BuiltinBinaryGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.BuiltinUnaryGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.DnnGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.MMTSJGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.MatrixAppendGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.MatrixIndexingGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.MatrixMatrixAxpyGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.MatrixReshapeGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.RelationalBinaryGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.ReorgGPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
/**
 * Parser for GPU instructions: holds the table that associates opcode strings
 * with {@link GPUINSTRUCTION_TYPE}s and dispatches an instruction string to the
 * {@code parseInstruction} method of the matching GPU instruction class.
 */
public class GPUInstructionParser extends InstructionParser
{
	/**
	 * Opcode string to GPU instruction type. Name, type and visibility are kept
	 * as-is so that existing accessors outside this class remain valid.
	 */
	static final HashMap<String, GPUINSTRUCTION_TYPE> String2GPUInstructionType;

	static {
		String2GPUInstructionType = new HashMap<>();

		// Neural Network Operators
		register(GPUINSTRUCTION_TYPE.Dnn,
			"relu_backward",
			"conv2d",
			"conv2d_bias_add",
			"conv2d_backward_filter",
			"conv2d_backward_data",
			"maxpooling",
			"maxpooling_backward",
			"avgpooling",
			"avgpooling_backward",
			"bias_add",
			"bias_multiply",
			"channel_sums",
			"lstm",
			"lstm_backward",
			"batch_norm2d",
			"batch_norm2d_backward",
			"batch_norm2d_test",
			"batch_norm2d_train",
			"update_nesterov_x");

		// Matrix Multiply Operators
		register(GPUINSTRUCTION_TYPE.AggregateBinary, "ba+*");
		register(GPUINSTRUCTION_TYPE.MMTSJ, "tsmm");

		// Reorg/Transpose
		register(GPUINSTRUCTION_TYPE.Reorg, "r'");
		register(GPUINSTRUCTION_TYPE.MatrixReshape, "rshape");

		// Matrix Manipulation
		register(GPUINSTRUCTION_TYPE.Append, "append");

		// Binary Cellwise
		register(GPUINSTRUCTION_TYPE.ArithmeticBinary,
			"+", "-", "*", "/", "%%", "%/%", "^",
			"1-*", // special * case
			"^2",  // special ^ case
			"*2",  // special * case
			"-nz", // special - case
			"+*", "-*");

		// Unary Builtin functions
		register(GPUINSTRUCTION_TYPE.BuiltinUnary,
			"exp", "log", "abs", "sqrt", "round", "floor", "ceil",
			"sin", "cos", "tan", "sinh", "cosh", "tanh",
			"asin", "acos", "atan",
			"sign", "sigmoid", "softmax");

		// Binary Builtin functions
		register(GPUINSTRUCTION_TYPE.BuiltinBinary, "solve", "min", "max");

		// Aggregate Unary
		register(GPUINSTRUCTION_TYPE.AggregateUnary,
			"ua+", "uak+",        // Sum
			"uar+", "uark+",      // Row Sum
			"uac+", "uack+",      // Col Sum
			"ua*",                // Multiplication
			"uamean",             // Mean
			"uarmean",            // Row Mean
			"uacmean",            // Col Mean
			"uamax",              // Max
			"uarmax",             // Row Max
			"uacmax",             // Col Max
			"uamin",              // Min
			"uarmin",             // Row Min
			"uacmin",             // Col Min
			"uasqk+",             // Sum of Squares
			"uarsqk+",            // Row Sum of Squares
			"uacsqk+",            // Col Sum of Squares
			"uavar",              // Variance
			"uarvar",             // Row Variance
			"uacvar");            // Col Variance

		// Cumulative Ops
		register(GPUINSTRUCTION_TYPE.BuiltinUnary,
			"ucumk+", "ucum*", "ucumk+*", "ucummin", "ucummax");

		// Relational Binary
		register(GPUINSTRUCTION_TYPE.RelationalBinary,
			"==", "!=", "<", ">", "<=", ">=");

		// Indexing
		register(GPUINSTRUCTION_TYPE.MatrixIndexing, RightIndex.OPCODE);
	}

	/**
	 * Adds one table entry per opcode, all pointing at the same instruction type.
	 * Entries are inserted in argument order.
	 *
	 * @param type    GPU instruction type to associate with every given opcode
	 * @param opcodes opcode strings to register
	 */
	private static void register(GPUINSTRUCTION_TYPE type, String... opcodes) {
		for( String opcode : opcodes )
			String2GPUInstructionType.put(opcode, type);
	}

	/**
	 * Parses an instruction string into a GPU instruction, resolving the
	 * instruction type through {@code InstructionUtils.getGPUType}.
	 *
	 * @param str instruction string
	 * @return the parsed GPU instruction, or {@code null} if {@code str} is
	 *         {@code null} or empty
	 * @throws DMLRuntimeException if no GPU instruction type could be resolved
	 *         for {@code str}, or if parsing yielded no instruction
	 */
	public static GPUInstruction parseSingleInstruction (String str ) {
		if( str == null || str.isEmpty() )
			return null;

		GPUINSTRUCTION_TYPE gputype = InstructionUtils.getGPUType(str);
		if( gputype == null )
			throw new DMLRuntimeException("Unable to derive GPU instruction type for instruction: " + str);

		GPUInstruction gpuinst = parseSingleInstruction(gputype, str);
		if( gpuinst == null )
			throw new DMLRuntimeException("Unable to parse instruction: " + str);
		return gpuinst;
	}

	/**
	 * Parses an instruction string using an already known GPU instruction type,
	 * delegating to the {@code parseInstruction} method of the instruction class
	 * that corresponds to that type.
	 *
	 * @param gputype GPU instruction type of the instruction
	 * @param str     instruction string
	 * @return the result of the delegate's {@code parseInstruction}, or
	 *         {@code null} if {@code str} is {@code null} or empty
	 * @throws DMLRuntimeException if {@code gputype} is {@code null} (for a
	 *         non-empty {@code str}) or is a type this parser does not handle
	 */
	public static GPUInstruction parseSingleInstruction ( GPUINSTRUCTION_TYPE gputype, String str ) {
		// The empty-string check deliberately comes first: an empty instruction
		// yields null even when no type is given.
		if( str == null || str.isEmpty() )
			return null;
		if( gputype == null )
			throw new DMLRuntimeException("The instruction is not GPU-enabled: " + str);

		switch(gputype) {
			case AggregateUnary:
				return AggregateUnaryGPUInstruction.parseInstruction(str);

			case AggregateBinary:
				return AggregateBinaryGPUInstruction.parseInstruction(str);

			case BuiltinUnary:
				return BuiltinUnaryGPUInstruction.parseInstruction(str);

			case BuiltinBinary:
				return BuiltinBinaryGPUInstruction.parseInstruction(str);

			case Append:
				return MatrixAppendGPUInstruction.parseInstruction(str);

			case Dnn:
				return DnnGPUInstruction.parseInstruction(str);

			case MMTSJ:
				return MMTSJGPUInstruction.parseInstruction(str);

			case Reorg:
				return ReorgGPUInstruction.parseInstruction(str);

			case MatrixReshape:
				return MatrixReshapeGPUInstruction.parseInstruction(str);

			case ArithmeticBinary: {
				// "+*" and "-*" share the ArithmeticBinary type but are handled
				// by the axpy instruction class rather than the generic one.
				String opcode = InstructionUtils.getOpCode(str);
				if( opcode.equals("+*") || opcode.equals("-*") )
					return MatrixMatrixAxpyGPUInstruction.parseInstruction(str);
				return ArithmeticBinaryGPUInstruction.parseInstruction(str);
			}

			case RelationalBinary:
				return RelationalBinaryGPUInstruction.parseInstruction(str);

			case MatrixIndexing:
				return MatrixIndexingGPUInstruction.parseInstruction(str);

			default:
				throw new DMLRuntimeException("Invalid GPU Instruction Type: " + gputype );
		}
	}
}