| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.controlprogram.federated; |
| |
| import java.util.Arrays; |
| import java.util.List; |
| import java.util.concurrent.Future; |
| |
| import org.apache.log4j.Logger; |
| import org.apache.sysds.common.Types.ExecType; |
| import org.apache.sysds.lops.Lop; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; |
| import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; |
| import org.apache.sysds.runtime.functionobjects.Builtin; |
| import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; |
| import org.apache.sysds.runtime.functionobjects.KahanFunction; |
| import org.apache.sysds.runtime.functionobjects.Mean; |
| import org.apache.sysds.runtime.functionobjects.Plus; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.cp.DoubleObject; |
| import org.apache.sysds.runtime.instructions.cp.ScalarObject; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; |
| import org.apache.sysds.runtime.matrix.operators.BinaryOperator; |
| import org.apache.sysds.runtime.matrix.operators.ScalarOperator; |
| import org.apache.sysds.runtime.matrix.operators.SimpleOperator; |
| |
| public class FederationUtils { |
| protected static Logger log = Logger.getLogger(FederationUtils.class); |
| private static final IDSequence _idSeq = new IDSequence(); |
| |
| public static void resetFedDataID() { |
| _idSeq.reset(); |
| } |
| |
| public static long getNextFedDataID() { |
| return _idSeq.getNextID(); |
| } |
| |
| public static FederatedRequest callInstruction(String inst, CPOperand varOldOut, CPOperand[] varOldIn, long[] varNewIn) { |
| //TODO better and safe replacement of operand names --> instruction utils |
| long id = getNextFedDataID(); |
| String linst = inst.replace(ExecType.SPARK.name(), ExecType.CP.name()); |
| linst = linst.replace( |
| Lop.OPERAND_DELIMITOR+varOldOut.getName()+Lop.DATATYPE_PREFIX, |
| Lop.OPERAND_DELIMITOR+String.valueOf(id)+Lop.DATATYPE_PREFIX); |
| for(int i=0; i<varOldIn.length; i++) |
| if( varOldIn[i] != null ) { |
| linst = linst.replace( |
| Lop.OPERAND_DELIMITOR+varOldIn[i].getName()+Lop.DATATYPE_PREFIX, |
| Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i])+Lop.DATATYPE_PREFIX); |
| linst = linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i])); //parameterized |
| } |
| return new FederatedRequest(RequestType.EXEC_INST, id, linst); |
| } |
| |
| public static MatrixBlock aggAdd(Future<FederatedResponse>[] ffr) { |
| try { |
| SimpleOperator op = new SimpleOperator(Plus.getPlusFnObject()); |
| MatrixBlock[] in = new MatrixBlock[ffr.length]; |
| for(int i=0; i<ffr.length; i++) |
| in[i] = (MatrixBlock) ffr[i].get().getData()[0]; |
| return MatrixBlock.naryOperations(op, in, new ScalarObject[0], new MatrixBlock()); |
| } |
| catch(Exception ex) { |
| throw new DMLRuntimeException(ex); |
| } |
| } |
| |
| public static MatrixBlock aggMean(Future<FederatedResponse>[] ffr, FederationMap map) { |
| try { |
| FederatedRange[] ranges = map.getFederatedRanges(); |
| BinaryOperator bop = InstructionUtils.parseBinaryOperator("+"); |
| ScalarOperator sop1 = InstructionUtils.parseScalarBinaryOperator("*", false); |
| MatrixBlock ret = null; |
| long size = 0; |
| for(int i=0; i<ffr.length; i++) { |
| MatrixBlock tmp = (MatrixBlock)ffr[i].get().getData()[0]; |
| size += ranges[i].getSize(0); |
| sop1 = sop1.setConstant(ranges[i].getSize(0)); |
| tmp = tmp.scalarOperations(sop1, new MatrixBlock()); |
| ret = (ret==null) ? tmp : ret.binaryOperationsInPlace(bop, tmp); |
| } |
| ScalarOperator sop2 = InstructionUtils.parseScalarBinaryOperator("/", false); |
| sop2 = sop2.setConstant(size); |
| return ret.scalarOperations(sop2, new MatrixBlock()); |
| } |
| catch(Exception ex) { |
| throw new DMLRuntimeException(ex); |
| } |
| } |
| |
| public static DoubleObject aggMinMax(Future<FederatedResponse>[] ffr, boolean isMin, boolean isScalar) { |
| try { |
| double res = isMin ? Double.MAX_VALUE : - Double.MAX_VALUE; |
| for (Future<FederatedResponse> fr: ffr){ |
| double v = isScalar ? ((ScalarObject)fr.get().getData()[0]).getDoubleValue() : |
| isMin ? ((MatrixBlock) fr.get().getData()[0]).min() : ((MatrixBlock) fr.get().getData()[0]).max(); |
| res = isMin ? Math.min(res, v) : Math.max(res, v); |
| } |
| return new DoubleObject(res); |
| } |
| catch (Exception ex) { |
| throw new DMLRuntimeException(ex); |
| } |
| } |
| |
| public static MatrixBlock[] getResults(Future<FederatedResponse>[] ffr) { |
| try { |
| MatrixBlock[] ret = new MatrixBlock[ffr.length]; |
| for(int i=0; i<ffr.length; i++) |
| ret[i] = (MatrixBlock) ffr[i].get().getData()[0]; |
| return ret; |
| } |
| catch(Exception ex) { |
| throw new DMLRuntimeException(ex); |
| } |
| } |
| |
| public static MatrixBlock rbind(Future<FederatedResponse>[] ffr) { |
| // TODO handle non-contiguous cases |
| try { |
| MatrixBlock[] tmp = getResults(ffr); |
| return tmp[0].append( |
| Arrays.copyOfRange(tmp, 1, tmp.length), |
| new MatrixBlock(), false); |
| } |
| catch(Exception ex) { |
| throw new DMLRuntimeException(ex); |
| } |
| } |
| |
| public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) { |
| if(!(aop.aggOp.increOp.fn instanceof KahanFunction || (aop.aggOp.increOp.fn instanceof Builtin && |
| (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN || |
| ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)))) { |
| throw new DMLRuntimeException("Unsupported aggregation operator: " |
| + aop.aggOp.increOp.getClass().getSimpleName()); |
| } |
| |
| try { |
| if(aop.aggOp.increOp.fn instanceof Builtin){ |
| // then we know it is a Min or Max based on the previous check. |
| boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN; |
| return aggMinMax(ffr, isMin, true); |
| } |
| else { |
| double sum = 0; //uak+ |
| for( Future<FederatedResponse> fr : ffr ) |
| sum += ((ScalarObject)fr.get().getData()[0]).getDoubleValue(); |
| return new DoubleObject(sum); |
| } |
| } |
| catch(Exception ex) { |
| throw new DMLRuntimeException(ex); |
| } |
| } |
| |
| public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) { |
| // handle row aggregate |
| if( aop.isRowAggregate() ) { |
| //independent of aggregation function for row-partitioned federated matrices |
| return rbind(ffr); |
| } |
| // handle col aggregate |
| if( aop.aggOp.increOp.fn instanceof KahanFunction ) |
| return aggAdd(ffr); |
| else if( aop.aggOp.increOp.fn instanceof Mean ) |
| return aggMean(ffr, map); |
| else if (aop.aggOp.increOp.fn instanceof Builtin && |
| (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN || |
| ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)) { |
| boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN; |
| return new MatrixBlock(1,1,aggMinMax(ffr, isMin, false).getDoubleValue()); |
| } |
| else |
| throw new DMLRuntimeException("Unsupported aggregation operator: " |
| + aop.aggOp.increOp.fn.getClass().getSimpleName()); |
| } |
| |
| public static void waitFor(List<Future<FederatedResponse>> responses) { |
| try { |
| for(Future<FederatedResponse> fr : responses) |
| fr.get(); |
| } |
| catch(Exception ex) { |
| throw new DMLRuntimeException(ex); |
| } |
| } |
| } |