blob: 1e2e6ea2155cbfe5f5d44c5bd2d5e51ac7552e4a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.controlprogram.federated;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.privacy.PrivacyPropagator;
import org.apache.sysds.utils.JSONHelper;
import org.apache.wink.json4j.JSONObject;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Map;
public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
protected static Logger log = Logger.getLogger(FederatedWorkerHandler.class);
private final IDSequence _seq;
private Map<Long, Data> _vars;
public FederatedWorkerHandler(IDSequence seq, Map<Long, Data> _vars2) {
_seq = seq;
_vars = _vars2;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
log.debug("Received: " + msg.getClass().getSimpleName());
FederatedRequest request;
if (msg instanceof FederatedRequest)
request = (FederatedRequest) msg;
else
throw new DMLRuntimeException("FederatedWorkerHandler: Received object no instance of `FederatedRequest`.");
FederatedRequest.FedMethod method = request.getMethod();
log.debug("Received command: " + method.name());
PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
PrivacyMonitor.clearCheckedConstraints();
synchronized (_seq) {
FederatedResponse response = constructResponse(request);
conditionalAddCheckedConstraints(request, response);
if (!response.isSuccessful())
log.error("Method " + method + " failed: " + response.getErrorMessage());
ctx.writeAndFlush(response).addListener(new CloseListener());
}
}
private static void conditionalAddCheckedConstraints(FederatedRequest request, FederatedResponse response){
if ( request.checkPrivacy() )
response.setCheckedConstraints(PrivacyMonitor.getCheckedConstraints());
}
private FederatedResponse constructResponse(FederatedRequest request) {
FederatedRequest.FedMethod method = request.getMethod();
try {
switch (method) {
case READ_MATRIX:
return readData(request, Types.DataType.MATRIX);
case READ_FRAME:
return readData(request, Types.DataType.FRAME);
case MATVECMULT:
return executeMatVecMult(request);
case TRANSFER:
return getVariableData(request);
case AGGREGATE:
return executeAggregation(request);
case SCALAR:
return executeScalarOperation(request);
default:
String message = String.format("Method %s is not supported.", method);
return new FederatedResponse(FederatedResponse.Type.ERROR, new FederatedWorkerHandlerException(message));
}
}
catch (DMLPrivacyException | FederatedWorkerHandlerException exception) {
return new FederatedResponse(FederatedResponse.Type.ERROR, exception);
}
catch (Exception exception) {
return new FederatedResponse(FederatedResponse.Type.ERROR,
new FederatedWorkerHandlerException("Exception of type "
+ exception.getClass() + " thrown when processing request"));
}
}
private FederatedResponse readData(FederatedRequest request, Types.DataType dataType) {
checkNumParams(request.getNumParams(), 1);
String filename = (String) request.getParam(0);
return readData(filename, dataType);
}
private FederatedResponse readData(String filename, Types.DataType dataType) {
MatrixCharacteristics mc = new MatrixCharacteristics();
mc.setBlocksize(ConfigurationManager.getBlocksize());
CacheableData<?> cd;
switch (dataType) {
case MATRIX:
cd = new MatrixObject(Types.ValueType.FP64, filename);
break;
case FRAME:
cd = new FrameObject(filename);
break;
default:
// should NEVER happen (if we keep request codes in sync with actual behaviour)
return new FederatedResponse(FederatedResponse.Type.ERROR,
new FederatedWorkerHandlerException("Could not recognize datatype"));
}
// read metadata
FileFormat fmt = null;
try {
String mtdname = DataExpression.getMTDFileName(filename);
Path path = new Path(mtdname);
try (FileSystem fs = IOUtilFunctions.getFileSystem(mtdname)) {
try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
JSONObject mtd = JSONHelper.parse(br);
if (mtd == null)
return new FederatedResponse(FederatedResponse.Type.ERROR, new FederatedWorkerHandlerException("Could not parse metadata file"));
mc.setRows(mtd.getLong(DataExpression.READROWPARAM));
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
cd = PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
fmt = FileFormat.safeValueOf(mtd.getString(DataExpression.FORMAT_TYPE));
}
}
}
catch (Exception ex) {
throw new DMLRuntimeException(ex);
}
cd.setMetaData(new MetaDataFormat(mc, fmt));
cd.acquireRead();
cd.refreshMetaData();
cd.release();
long id = _seq.getNextID();
_vars.put(id, cd);
if (dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
return new FederatedResponse(FederatedResponse.Type.SUCCESS, new Object[] {id, frameObject.getSchema()});
}
return new FederatedResponse(FederatedResponse.Type.SUCCESS, id);
}
private FederatedResponse executeMatVecMult(FederatedRequest request) {
checkNumParams(request.getNumParams(), 3);
MatrixBlock vector = (MatrixBlock) request.getParam(0);
boolean isMatVecMult = (Boolean) request.getParam(1);
long varID = (Long) request.getParam(2);
return executeMatVecMult(varID, vector, isMatVecMult);
}
private FederatedResponse executeMatVecMult(long varID, MatrixBlock vector, boolean isMatVecMult) {
MatrixObject matTo = (MatrixObject) _vars.get(varID);
matTo = PrivacyMonitor.handlePrivacy(matTo);
MatrixBlock matBlock1 = matTo.acquireReadAndRelease();
// TODO other datatypes
AggregateBinaryOperator ab_op = InstructionUtils
.getMatMultOperator(OptimizerUtils.getConstrainedNumThreads(0));
MatrixBlock result = isMatVecMult ?
matBlock1.aggregateBinaryOperations(matBlock1, vector, new MatrixBlock(), ab_op) :
vector.aggregateBinaryOperations(vector, matBlock1, new MatrixBlock(), ab_op);
return new FederatedResponse(FederatedResponse.Type.SUCCESS, result);
}
private FederatedResponse getVariableData(FederatedRequest request) {
checkNumParams(request.getNumParams(), 1);
long varID = (Long) request.getParam(0);
return getVariableData(varID);
}
private FederatedResponse getVariableData(long varID) {
Data dataObject = _vars.get(varID);
dataObject = PrivacyMonitor.handlePrivacy(dataObject);
switch (dataObject.getDataType()) {
case TENSOR:
return new FederatedResponse(FederatedResponse.Type.SUCCESS,
((TensorObject) dataObject).acquireReadAndRelease());
case MATRIX:
return new FederatedResponse(FederatedResponse.Type.SUCCESS,
((MatrixObject) dataObject).acquireReadAndRelease());
case FRAME:
return new FederatedResponse(FederatedResponse.Type.SUCCESS,
((FrameObject) dataObject).acquireReadAndRelease());
case LIST:
return new FederatedResponse(FederatedResponse.Type.SUCCESS, ((ListObject) dataObject).getData());
// TODO rest of the possible datatypes
default:
return new FederatedResponse(FederatedResponse.Type.ERROR,
new FederatedWorkerHandlerException("Not possible to send datatype " + dataObject.getDataType().name()));
}
}
private FederatedResponse executeAggregation(FederatedRequest request) {
checkNumParams(request.getNumParams(), 2);
AggregateUnaryOperator operator = (AggregateUnaryOperator) request.getParam(0);
long varID = (Long) request.getParam(1);
return executeAggregation(varID, operator);
}
private FederatedResponse executeAggregation(long varID, AggregateUnaryOperator operator) {
Data dataObject = _vars.get(varID);
if (dataObject.getDataType() != Types.DataType.MATRIX) {
return new FederatedResponse(FederatedResponse.Type.ERROR,
new FederatedWorkerHandlerException("Aggregation only supported for matrices, not for "
+ dataObject.getDataType().name()));
}
MatrixObject matrixObject = (MatrixObject) dataObject;
matrixObject = PrivacyMonitor.handlePrivacy(matrixObject);
MatrixBlock matrixBlock = matrixObject.acquireRead();
// create matrix for calculation with correction
MatrixCharacteristics mc = new MatrixCharacteristics();
// find out the characteristics after aggregation
operator.indexFn.computeDimension(matrixObject.getDataCharacteristics(), mc);
// make outBlock right size
int outNumRows = (int) mc.getRows();
int outNumCols = (int) mc.getCols();
if (operator.aggOp.existsCorrection()) {
// add rows for correction
int numMissing = operator.aggOp.correction.getNumRemovedRowsColumns();
if (operator.aggOp.correction.isRows())
outNumRows += numMissing;
else
outNumCols += numMissing;
}
MatrixBlock ret = new MatrixBlock(outNumRows, outNumCols, operator.aggOp.initialValue);
LibMatrixAgg.aggregateUnaryMatrix(matrixBlock, ret, operator);
// result block without correction
ret.dropLastRowsOrColumns(operator.aggOp.correction);
return new FederatedResponse(FederatedResponse.Type.SUCCESS, ret);
}
private FederatedResponse executeScalarOperation(FederatedRequest request) {
checkNumParams(request.getNumParams(), 2);
ScalarOperator operator = (ScalarOperator) request.getParam(0);
long varID = (Long) request.getParam(1);
return executeScalarOperation(varID, operator);
}
private FederatedResponse executeScalarOperation(long varID, ScalarOperator operator) {
Data dataObject = _vars.get(varID);
dataObject = PrivacyMonitor.handlePrivacy(dataObject);
if (dataObject.getDataType() != Types.DataType.MATRIX) {
return new FederatedResponse(FederatedResponse.Type.ERROR,
new FederatedWorkerHandlerException("FederatedWorkerHandler: ScalarOperator dont support "
+ dataObject.getDataType().name()));
}
MatrixObject matrixObject = (MatrixObject) dataObject;
MatrixBlock inBlock = matrixObject.acquireRead();
MatrixBlock retBlock = inBlock.scalarOperations(operator, new MatrixBlock());
return new FederatedResponse(FederatedResponse.Type.SUCCESS, retBlock);
}
@SuppressWarnings("unused")
private FederatedResponse createMatrixObject(MatrixBlock result) {
MatrixObject resTo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName());
MetaDataFormat metadata = new MetaDataFormat(
new MatrixCharacteristics(result.getNumRows(), result.getNumColumns()), FileFormat.BINARY);
resTo.setMetaData(metadata);
resTo.acquireModify(result);
resTo.release();
long result_var = _seq.getNextID();
_vars.put(result_var, resTo);
return new FederatedResponse(FederatedResponse.Type.SUCCESS, result_var);
}
private static void checkNumParams(int actual, int... expected) {
if (Arrays.stream(expected).anyMatch(x -> x == actual))
return;
throw new DMLRuntimeException("FederatedWorkerHandler: Received wrong amount of params:" + " expected="
+ Arrays.toString(expected) + ", actual=" + actual);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
}
private static class CloseListener implements ChannelFutureListener {
@Override
public void operationComplete(ChannelFuture channelFuture) throws InterruptedException, DMLRuntimeException {
if (!channelFuture.isSuccess())
throw new DMLRuntimeException("Federated Worker Write failed");
PrivacyMonitor.clearCheckedConstraints();
channelFuture.channel().close().sync();
}
}
}