| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.instructions.fed; |
| |
| import java.util.HashMap; |
| import java.util.Map; |
| |
| import org.apache.sysds.common.Types; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; |
| import org.apache.sysds.runtime.controlprogram.federated.FederationMap; |
| import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; |
| import org.apache.sysds.runtime.functionobjects.DiagIndex; |
| import org.apache.sysds.runtime.functionobjects.RevIndex; |
| import org.apache.sysds.runtime.functionobjects.SwapIndex; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.cp.Data; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.operators.Operator; |
| import org.apache.sysds.runtime.matrix.operators.ReorgOperator; |
| |
| public class ReorgFEDInstruction extends UnaryFEDInstruction { |
| |
| public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr) { |
| super(FEDType.Reorg, op, in1, out, opcode, istr); |
| } |
| |
| public static ReorgFEDInstruction parseInstruction ( String str ) { |
| CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN); |
| CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN); |
| |
| String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
| String opcode = parts[0]; |
| if ( opcode.equalsIgnoreCase("r'") ) { |
| InstructionUtils.checkNumFields(str, 2, 3); |
| in.split(parts[1]); |
| out.split(parts[2]); |
| int k = Integer.parseInt(parts[3]); |
| return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str); |
| } |
| else if ( opcode.equalsIgnoreCase("rdiag") ) { |
| parseUnaryInstruction(str, in, out); //max 2 operands |
| return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str); |
| } |
| else if ( opcode.equalsIgnoreCase("rev") ) { |
| parseUnaryInstruction(str, in, out); //max 2 operands |
| return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str); |
| } |
| else { |
| throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: "+opcode); |
| } |
| } |
| |
| @Override |
| public void processInstruction(ExecutionContext ec) { |
| MatrixObject mo1 = ec.getMatrixObject(input1); |
| ReorgOperator r_op = (ReorgOperator) _optr; |
| |
| if( !mo1.isFederated() ) |
| throw new DMLRuntimeException("Federated Reorg: " |
| + "Federated input expected, but invoked w/ "+mo1.isFederated()); |
| |
| if(instOpcode.equals("r'")) { |
| //execute transpose at federated site |
| FederatedRequest fr1 = FederationUtils.callInstruction(instString, |
| output, |
| new CPOperand[] {input1}, |
| new long[] {mo1.getFedMapping().getID()}); |
| mo1.getFedMapping().execute(getTID(), true, fr1); |
| |
| //drive output federated mapping |
| MatrixObject out = ec.getMatrixObject(output); |
| out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz()); |
| out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose()); |
| } |
| else if(instOpcode.equalsIgnoreCase("rev")) { |
| //execute transpose at federated site |
| FederatedRequest fr1 = FederationUtils.callInstruction(instString, |
| output, |
| new CPOperand[] {input1}, |
| new long[] {mo1.getFedMapping().getID()}); |
| mo1.getFedMapping().execute(getTID(), true, fr1); |
| |
| if(mo1.isFederated(FederationMap.FType.ROW)) |
| mo1.getFedMapping().reverseFedMap(); |
| |
| //derive output federated mapping |
| MatrixObject out = ec.getMatrixObject(output); |
| out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz()); |
| out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID())); |
| } |
| else if (instOpcode.equals("rdiag")) { |
| RdiagResult result; |
| // diag(diag(X)) |
| if (mo1.getNumColumns() == 1 && mo1.getNumRows() != 1) { |
| result = rdiagV2M(mo1, r_op); |
| } else { |
| result = rdiagM2V(mo1, r_op); |
| } |
| |
| FederationMap diagFedMap = result.getFedMap(); |
| Map<FederatedRange, int[]> dcs = result.getDcs(); |
| |
| //update fed ranges |
| for(int i = 0; i < diagFedMap.getFederatedRanges().length; i++) { |
| int[] newRange = dcs.get(diagFedMap.getFederatedRanges()[i]); |
| |
| diagFedMap.getFederatedRanges()[i].setBeginDim(0, |
| (diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 || |
| i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]); |
| diagFedMap.getFederatedRanges()[i].setEndDim(0, |
| diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]); |
| diagFedMap.getFederatedRanges()[i].setBeginDim(1, |
| (diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 || |
| i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]); |
| diagFedMap.getFederatedRanges()[i].setEndDim(1, |
| diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]); |
| } |
| |
| //update output mapping and data characteristics |
| MatrixObject rdiag = ec.getMatrixObject(output); |
| rdiag.getDataCharacteristics() |
| .set(diagFedMap.getMaxIndexInRange(0), diagFedMap.getMaxIndexInRange(1), |
| (int) mo1.getBlocksize()); |
| rdiag.setFedMapping(diagFedMap); |
| } |
| } |
| |
| private class RdiagResult { |
| FederationMap fedMap; |
| Map<FederatedRange, int[]> dcs; |
| |
| public RdiagResult(FederationMap fedMap, Map<FederatedRange, int[]> dcs) { |
| this.fedMap = fedMap; |
| this.dcs = dcs; |
| } |
| |
| public FederationMap getFedMap() { |
| return fedMap; |
| } |
| |
| public Map<FederatedRange, int[]> getDcs() { |
| return dcs; |
| } |
| } |
| |
| private RdiagResult rdiagV2M (MatrixObject mo1, ReorgOperator r_op) { |
| FederationMap fedMap = mo1.getFedMapping(); |
| boolean rowFed = mo1.isFederated(FederationMap.FType.ROW); |
| |
| long varID = FederationUtils.getNextFedDataID(); |
| Map<FederatedRange, int[]> dcs = new HashMap<>(); |
| FederationMap diagFedMap; |
| |
| diagFedMap = fedMap.mapParallel(varID, (range, data) -> { |
| try { |
| FederatedResponse response = data.executeFederatedOperation(new FederatedRequest( |
| FederatedRequest.RequestType.EXEC_UDF, -1, |
| new ReorgFEDInstruction.DiagMatrix(data.getVarID(), |
| varID, r_op, |
| rowFed ? (new int[] {range.getBeginDimsInt()[0], range.getEndDimsInt()[0]}) : |
| new int[] {range.getBeginDimsInt()[1], range.getEndDimsInt()[1]}, |
| rowFed, (int) mo1.getNumRows()))).get(); |
| if(!response.isSuccessful()) |
| response.throwExceptionFromResponse(); |
| int[] subRangeCharacteristics = (int[]) response.getData()[0]; |
| synchronized(dcs) { |
| dcs.put(range, subRangeCharacteristics); |
| } |
| return null; |
| } |
| catch(Exception e) { |
| throw new DMLRuntimeException(e); |
| } |
| }); |
| return new RdiagResult(diagFedMap, dcs); |
| } |
| |
| private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) { |
| FederationMap fedMap = mo1.getFedMapping(); |
| boolean rowFed = mo1.isFederated(FederationMap.FType.ROW); |
| |
| long varID = FederationUtils.getNextFedDataID(); |
| Map<FederatedRange, int[]> dcs = new HashMap<>(); |
| FederationMap diagFedMap; |
| |
| diagFedMap = fedMap.mapParallel(varID, (range, data) -> { |
| try { |
| FederatedResponse response = data.executeFederatedOperation(new FederatedRequest( |
| FederatedRequest.RequestType.EXEC_UDF, -1, |
| new ReorgFEDInstruction.Rdiag(data.getVarID(), varID, r_op, |
| rowFed ? (new int[] {range.getBeginDimsInt()[0], range.getEndDimsInt()[0]}) : |
| new int[] {range.getBeginDimsInt()[1], range.getEndDimsInt()[1]}, |
| rowFed))).get(); |
| if(!response.isSuccessful()) |
| response.throwExceptionFromResponse(); |
| int[] subRangeCharacteristics = (int[]) response.getData()[0]; |
| synchronized(dcs) { |
| dcs.put(range, subRangeCharacteristics); |
| } |
| return null; |
| } |
| catch(Exception e) { |
| throw new DMLRuntimeException(e); |
| } |
| }); |
| return new RdiagResult(diagFedMap, dcs); |
| } |
| |
| private static class Rdiag extends FederatedUDF { |
| |
| private static final long serialVersionUID = -3466926635958851402L; |
| private final long _outputID; |
| private final ReorgOperator _r_op; |
| private final int[] _slice; |
| private final boolean _rowFed; |
| |
| private Rdiag(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed) { |
| super(new long[] {input}); |
| _outputID = outputID; |
| _r_op = r_op; |
| _slice = slice; |
| _rowFed = rowFed; |
| } |
| |
| @Override |
| public FederatedResponse execute(ExecutionContext ec, Data... data) { |
| MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); |
| MatrixBlock soresBlock; |
| MatrixBlock res; |
| |
| soresBlock = _rowFed ? |
| mb.slice(0, mb.getNumRows() - 1, _slice[0], _slice[1] - 1, new MatrixBlock()) : |
| mb.slice(_slice[0], _slice[1] - 1); |
| res = soresBlock.reorgOperations(_r_op, new MatrixBlock(), 0, 0, 0); |
| |
| MatrixObject mout = ExecutionContext.createMatrixObject(res); |
| mout.setDiag(true); |
| ec.setVariable(String.valueOf(_outputID), mout); |
| |
| return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{res.getNumRows(), res.getNumColumns()}); |
| } |
| } |
| |
| private static class DiagMatrix extends FederatedUDF { |
| |
| private static final long serialVersionUID = -3466926635958851402L; |
| private final long _outputID; |
| private final ReorgOperator _r_op; |
| private final int _len; |
| private final int[] _slice; |
| private final boolean _rowFed; |
| |
| private DiagMatrix(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed, int len) { |
| super(new long[] {input}); |
| _outputID = outputID; |
| _r_op = r_op; |
| _len = len; |
| _rowFed = rowFed; |
| _slice = slice; |
| } |
| |
| @Override |
| public FederatedResponse execute(ExecutionContext ec, Data... data) { |
| MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); |
| MatrixBlock res; |
| |
| MatrixBlock tmp = mb.reorgOperations(_r_op, new MatrixBlock(), 0, 0, 0); |
| if(_rowFed) { |
| res = new MatrixBlock(mb.getNumRows(), _len, 0.0); |
| res.copy(0, res.getNumRows()-1, _slice[0], _slice[1]-1, tmp, false); |
| } else { |
| res = new MatrixBlock(_len, _slice[1], 0.0); |
| res.copy(_slice[0], _slice[1]-1, 0, mb.getNumColumns() - 1, tmp, false);; |
| } |
| MatrixObject mout = ExecutionContext.createMatrixObject(res); |
| mout.setDiag(true); |
| ec.setVariable(String.valueOf(_outputID), mout); |
| |
| return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{res.getNumRows(), res.getNumColumns()}); |
| } |
| } |
| } |