[SYSTEMDS-2634] Reduced number of RPC calls in federated backend
This patch improves the performance of the federated runtime backend by
merging the execution and cleanup RPC request batches into a single
batch of requests. Since every batch returns only a single response, we
now carefully select which response to return: an error, the get_var
result, or otherwise the last response. Overall, this reduces the number
of RPC calls by almost 2x and removes unnecessary synchronization barriers.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 720534a..4d0d5d9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -680,7 +680,7 @@
//clear federated matrix
if( _fedMapping != null )
- _fedMapping.cleanup(tid, _fedMapping.getID());
+ _fedMapping.execCleanup(tid, _fedMapping.getID());
// change object state EMPTY
setDirty(false);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index f4af303..0dcb846 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -91,10 +91,24 @@
PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
PrivacyMonitor.clearCheckedConstraints();
- response = executeCommand(request);
- conditionalAddCheckedConstraints(request, response);
- if (!response.isSuccessful()){
- log.error("Command " + request.getType() + " failed: " + response.getErrorMessage() + "full command: \n" + request.toString());
+ //execute command and handle privacy constraints
+ FederatedResponse tmp = executeCommand(request);
+ conditionalAddCheckedConstraints(request, tmp);
+
+ //select the response for the entire batch of requests
+ if (!tmp.isSuccessful()) {
+ log.error("Command " + request.getType() + " failed: "
+ + tmp.getErrorMessage() + "full command: \n" + request.toString());
+ response = (response == null || response.isSuccessful())
+ ? tmp : response; //return first error
+ }
+ else if( request.getType() == RequestType.GET_VAR ) {
+ if( response != null && response.isSuccessful() )
+ log.error("Multiple GET_VAR are not supported in single batch of requests.");
+ response = tmp; //return last get result
+ }
+ else if( response == null && i == requests.length-1 ) {
+ response = tmp; //return last
}
}
ctx.writeAndFlush(response).addListener(new CloseListener());
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b272bf9..72d1196 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -170,7 +170,14 @@
return readResponses;
}
- public void cleanup(long tid, long... id) {
+ public FederatedRequest cleanup(long tid, long... id) {
+ FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
+ VariableCPInstruction.prepareRemoveInstruction(id).toString());
+ request.setTID(tid);
+ return request;
+ }
+
+ public void execCleanup(long tid, long... id) {
FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
VariableCPInstruction.prepareRemoveInstruction(id).toString());
request.setTID(tid);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index faae560..7df7c51 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -32,7 +32,6 @@
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 34caec2..c28a163 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -68,10 +68,10 @@
new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2);
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else if(mo1.isFederated(FType.ROW)) { // MV + MM
@@ -81,16 +81,16 @@
new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
if( mo2.getNumColumns() == 1 ) { //MV
FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.rbind(tmp);
- mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //MM
//execute federated operations and aggregate
- mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
- mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), mo2.getNumColumns()));
@@ -104,10 +104,10 @@
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //other combinations
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e87bf57..60fe40b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -55,19 +55,19 @@
public void processInstruction(ExecutionContext ec) {
AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
MatrixObject in = ec.getMatrixObject(input1);
+ FederationMap map = in.getFedMapping();
//create federated commands for aggregation
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
//execute federated commands and cleanups
- FederationMap map = in.getFedMapping();
- Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2);
+ Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
if( output.isScalar() )
ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
else
ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
- map.cleanup(getTID(), fr1.getID());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 63c2d71..bceb6ae 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -42,39 +42,36 @@
FederatedRequest fr2 = null;
if( mo2.isFederated() ) {
- if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)){
+ if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
mo1.getFedMapping().execute(getTID(), true, fr2);
-
- } else{
+ }
+ else {
throw new DMLRuntimeException("Matrix-matrix binary operations "
+ " with a federated right input are not supported yet.");
}
-
- }
+ }
else {
//matrix-matrix binary oFederatedRequest fr2 = null;perations -> lhs fed input -> fed output
-
if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV row vector
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
//execute federated instruction and cleanup intermediates
- mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
- mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
}
else { //MM or MV col vector
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated instruction and cleanup intermediates
- mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
- mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
}
}
-
//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getDataCharacteristics());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 75bfe33..b6ea1fb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -39,17 +39,20 @@
CPOperand scalar = input2.isScalar() ? input2 : input1;
MatrixObject mo = ec.getMatrixObject(matrix);
- //execute federated matrix-scalar operation and cleanups
+ //prepare federated request matrix-scalar
FederatedRequest fr1 = !scalar.isLiteral() ?
mo.getFedMapping().broadcast(ec.getScalarInput(scalar)) : null;
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{matrix, (fr1 != null)?scalar:null},
new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1});
- mo.getFedMapping().execute(getTID(), true, (fr1!=null) ?
- new FederatedRequest[]{fr1, fr2}: new FederatedRequest[]{fr2});
- if( fr1 != null )
- mo.getFedMapping().cleanup(getTID(), fr1.getID());
+ //execute federated matrix-scalar operation and cleanups
+ if( fr1 != null ) {
+ FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1.getID());
+ mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ }
+ else
+ mo.getFedMapping().execute(getTID(), true, fr2);
//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 2dee64b..99a305b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -86,11 +86,12 @@
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping()
+ .cleanup(getTID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //XtwXv | XtXvy
@@ -101,11 +102,12 @@
new CPOperand[]{input1, input2, input3},
new long[]{mo1.getFedMapping().getID(), fr1.getID(), fr0[0].getID()});
FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping()
+ .cleanup(getTID(), fr0[0].getID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo1.getFedMapping().cleanup(getTID(), fr0[0].getID(), fr1.getID(), fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 292bced..fbe88d6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -67,11 +67,11 @@
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2);
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo1.getFedMapping().cleanup(getTID(), fr1.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //other combinations