[SYSTEMDS-2629,2630] Extended federated backend (r', perf, correctness)
This patch adds a federated transpose instruction (r') and support for
aligned federated-federated matrix multiplication. It also explicitly checks
and maintains the partitioning scheme of federated matrices.
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 653d7eb..74d4457 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -89,8 +89,8 @@
public static final String PRINT_GPU_MEMORY_INFO = "sysds.gpu.print.memoryInfo";
public static final String EVICTION_SHADOW_BUFFERSIZE = "sysds.gpu.eviction.shadow.bufferSize";
- public static final String DEFAULT_FEDERATED_PORT = "4040"; // borrowed default Spark Port
- public static final String DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS = "1";
+ public static final int DEFAULT_FEDERATED_PORT = 4040; // borrowed default Spark Port
+ public static final int DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS = 2;
//internal config
public static final String DEFAULT_SHARED_DIR_PERMISSION = "777"; //for local fs and DFS
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 949e60a..f7e893f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -33,6 +33,7 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.cp.Data;
@@ -334,6 +335,10 @@
return _fedMapping != null;
}
+ public boolean isFederated(FType type) {
+ return isFederated() && _fedMapping.getType() == type;
+ }
+
/**
* Gets the mapping of indices ranges to federated objects.
* @return fedMapping mapping
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 2c5f902..296e6f2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -49,7 +49,7 @@
* The ID of default matrix/tensor on which operations get executed if no other ID is given.
*/
private long _varID = -1; // -1 is never valid since varIDs start at 0
- private int _nrThreads = Integer.parseInt(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS);
+ private int _nrThreads = DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS;
public FederatedData(Types.DataType dataType, InetSocketAddress address, String filepath) {
@@ -88,6 +88,11 @@
return _varID != -1;
}
+ boolean equalAddress(FederatedData that) {
+ return _address != null && that != null && that._address != null
+ && _address.equals(that._address);
+ }
+
public synchronized Future<FederatedResponse> initFederatedData(long id) {
if(isInitialized())
throw new DMLRuntimeException("Tried to init already initialized data");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 46ebce2..23d0269 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -109,4 +109,14 @@
_endDims[1] += cshift;
return this;
}
+
+ public FederatedRange transpose() {
+ long tmpBeg = _beginDims[0];
+ long tmpEnd = _endDims[0];
+ _beginDims[0] = _beginDims[1];
+ _endDims[0] = _endDims[1];
+ _beginDims[1] = tmpBeg;
+ _endDims[1] = tmpEnd;
+ return this;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index dae75e4..c51254b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -41,8 +41,7 @@
public FederatedWorker(int port) {
_ecm = new ExecutionContextMap();
- _port = (port == -1) ?
- Integer.parseInt(DMLConfig.DEFAULT_FEDERATED_PORT) : port;
+ _port = (port == -1) ? DMLConfig.DEFAULT_FEDERATED_PORT : port;
}
public void run() {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 00a8685..bb64acc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -45,7 +45,6 @@
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
@@ -115,14 +114,15 @@
return execUDF(request);
default:
String message = String.format("Method %s is not supported.", method);
- return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException(message));
+ return new FederatedResponse(ResponseType.ERROR,
+ new FederatedWorkerHandlerException(message));
}
}
catch (DMLPrivacyException | FederatedWorkerHandlerException ex) {
- return new FederatedResponse(FederatedResponse.ResponseType.ERROR, ex);
+ return new FederatedResponse(ResponseType.ERROR, ex);
}
catch (Exception ex) {
- return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
+ return new FederatedResponse(ResponseType.ERROR,
new FederatedWorkerHandlerException("Exception of type "
+ ex.getClass() + " thrown when processing request", ex));
}
@@ -148,7 +148,7 @@
break;
default:
// should NEVER happen (if we keep request codes in sync with actual behaviour)
- return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
+ return new FederatedResponse(ResponseType.ERROR,
new FederatedWorkerHandlerException("Could not recognize datatype"));
}
@@ -161,7 +161,8 @@
try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
JSONObject mtd = JSONHelper.parse(br);
if (mtd == null)
- return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Could not parse metadata file"));
+ return new FederatedResponse(ResponseType.ERROR,
+ new FederatedWorkerHandlerException("Could not parse metadata file"));
mc.setRows(mtd.getLong(DataExpression.READROWPARAM));
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
cd = (CacheableData<?>) PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
@@ -172,23 +173,21 @@
catch (Exception ex) {
throw new DMLRuntimeException(ex);
}
- cd.setMetaData(new MetaDataFormat(mc, fmt));
- // TODO send FileFormatProperties with request and use them for CSV, this is currently a workaround so reading
- // of CSV files works
- cd.setFileFormatProperties(new FileFormatPropertiesCSV());
- cd.acquireRead();
- cd.refreshMetaData(); //in pinned state
- cd.release();
- //TODO spawn async load of data, otherwise on first access
- _ecm.get(tid).setVariable(String.valueOf(id), cd);
+ //put meta data object in symbol table, read on first operation
+ cd.setMetaData(new MetaDataFormat(mc, fmt));
cd.enableCleanup(false); //guard against deletion
+ _ecm.get(tid).setVariable(String.valueOf(id), cd);
if (dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
- return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {id, frameObject.getSchema()});
+ frameObject.acquireRead();
+ frameObject.refreshMetaData(); //get block schema
+ frameObject.release();
+ return new FederatedResponse(ResponseType.SUCCESS,
+ new Object[] {id, frameObject.getSchema()});
}
- return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, id);
+ return new FederatedResponse(ResponseType.SUCCESS, id);
}
private FederatedResponse putVariable(FederatedRequest request) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 371c3ff..b25f8b9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -43,26 +43,46 @@
public class FederationMap
{
+ public enum FType {
+ ROW, //row partitioned, groups of rows
+ COL, //column partitioned, groups of columns
+ OTHER,
+ }
+
private long _ID = -1;
private final Map<FederatedRange, FederatedData> _fedMap;
+ private FType _type;
public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
this(-1, fedMap);
}
public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap) {
+ this(ID, fedMap, FType.OTHER);
+ }
+
+ public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap, FType type) {
_ID = ID;
_fedMap = fedMap;
+ _type = type;
}
public long getID() {
return _ID;
}
+ public FType getType() {
+ return _type;
+ }
+
public boolean isInitialized() {
return _ID >= 0;
}
+ public void setType(FType type) {
+ _type = type;
+ }
+
public FederatedRange[] getFederatedRanges() {
return _fedMap.keySet().toArray(new FederatedRange[0]);
}
@@ -96,6 +116,19 @@
return ret.toArray(new FederatedRequest[0]);
}
+ public boolean isAligned(FederationMap that, boolean transposed) {
+ //determines if the two federation maps are aligned row/column partitions
+ //at the same federated sites (which allows for purely federated operations)
+ boolean ret = true;
+ for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+ FederatedRange range = !transposed ? e.getKey() :
+ new FederatedRange(e.getKey()).transpose();
+ FederatedData dat2 = that._fedMap.get(range);
+ ret &= e.getValue().equalAddress(dat2);
+ }
+ return ret;
+ }
+
public Future<FederatedResponse>[] execute(long tid, FederatedRequest... fr) {
return execute(tid, false, fr);
}
@@ -120,13 +153,9 @@
// prepare results (future federated responses), with optional wait to ensure the
// order of requests without data dependencies (e.g., cleanup RPCs)
- Future<FederatedResponse>[] ret2 = ret.toArray(new Future[0]);
- if( wait ) {
- Arrays.stream(ret2).forEach(e -> {
- try {e.get();} catch(Exception ex) {throw new DMLRuntimeException(ex);}
- });
- }
- return ret2;
+ if( wait )
+ waitFor(ret);
+ return ret.toArray(new Future[0]);
}
public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
@@ -145,8 +174,21 @@
FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
VariableCPInstruction.prepareRemoveInstruction(id).toString());
request.setTID(tid);
+ List<Future<FederatedResponse>> tmp = new ArrayList<>();
for(FederatedData fd : _fedMap.values())
- fd.executeFederatedOperation(request);
+ tmp.add(fd.executeFederatedOperation(request));
+ //wait to avoid interference w/ following requests
+ waitFor(tmp);
+ }
+
+ private static void waitFor(List<Future<FederatedResponse>> responses) {
+ try {
+ for(Future<FederatedResponse> fr : responses)
+ fr.get();
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
}
private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
@@ -164,7 +206,7 @@
//TODO handling of file path, but no danger as never written
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
map.put(new FederatedRange(e.getKey()), new FederatedData(e.getValue(), id));
- return new FederationMap(id, map);
+ return new FederationMap(id, map, _type);
}
public FederationMap copyWithNewID(long id, long clen) {
@@ -183,6 +225,24 @@
}
return this;
}
+
+ public FederationMap transpose() {
+ Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
+ _fedMap.clear();
+ for( Entry<FederatedRange, FederatedData> e : tmp.entrySet() ) {
+ _fedMap.put(
+ new FederatedRange(e.getKey()).transpose(),
+ new FederatedData(e.getValue(), _ID));
+ }
+ //derive output type
+ switch(_type) {
+ case ROW: _type = FType.COL; break;
+ case COL: _type = FType.ROW; break;
+ default: _type = FType.OTHER;
+ }
+ return this;
+ }
+
/**
* Execute a function for each <code>FederatedRange</code> + <code>FederatedData</code> pair. The function should
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6fd6173..34caec2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -26,6 +26,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -61,7 +62,19 @@
MatrixObject mo2 = ec.getMatrixObject(input2);
//#1 federated matrix-vector multiplication
- if(mo1.isFederated()) { // MV + MM
+ if(mo1.isFederated(FType.COL) && mo2.isFederated(FType.ROW)
+ && mo1.getFedMapping().isAligned(mo2.getFedMapping(), true) ) {
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
+ FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ else if(mo1.isFederated(FType.ROW)) { // MV + MM
//construct commands: broadcast rhs, fed mv, retrieve results
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
@@ -81,10 +94,11 @@
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), mo2.getNumColumns()));
+ out.getFedMapping().setType(FType.ROW);
}
}
//#2 vector - federated matrix multiplication
- else if (mo2.isFederated()) {// VM + MM
+ else if (mo2.isFederated(FType.ROW)) {// VM + MM
//construct commands: broadcast rhs, fed mv, retrieve results
FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
@@ -98,7 +112,8 @@
}
else { //other combinations
throw new DMLRuntimeException("Federated AggregateBinary not supported with the "
- + "following federated objects: "+mo1.isFederated()+" "+mo2.isFederated());
+ + "following federated objects: "+mo1.isFederated()+":"+mo1.getFedMapping()
+ +" "+mo2.isFederated()+":"+mo2.getFedMapping());
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 985d117..d17b7b5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -23,6 +23,7 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -76,7 +77,7 @@
+ " vs " + mo2.getNumColumns());
}
- if( mo1.isFederated() && _cbind ) {
+ if( mo1.isFederated(FType.ROW) && _cbind ) {
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
@@ -87,7 +88,7 @@
dc1.getBlocksize(), dc1.getNonZeros()+dc2.getNonZeros());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
}
- else if( mo1.isFederated() && mo2.isFederated() && !_cbind ) {
+ else if( mo1.isFederated(FType.ROW) && mo2.isFederated(FType.ROW) && !_cbind ) {
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(dc1.getRows()+dc2.getRows(), dc1.getCols(),
dc1.getBlocksize(), dc1.getNonZeros()+dc2.getNonZeros());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 77dedfd..292702e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -36,6 +36,7 @@
ParameterizedBuiltin,
Tsmm,
MMChain,
+ Reorg,
}
protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index bbdaa8e..c8cf729 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -22,6 +22,7 @@
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.*;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
@@ -40,7 +41,7 @@
if( instruction.input1.isMatrix() && instruction.input2.isMatrix() ) {
MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
MatrixObject mo2 = ec.getMatrixObject(instruction.input2);
- if (mo1.isFederated() || mo2.isFederated()) {
+ if (mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) {
fedinst = AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
}
@@ -89,6 +90,12 @@
}
}
}
+ else if(inst instanceof ReorgCPInstruction && inst.getOpcode().equals("r'")) {
+ ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
+ CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+ if( mo.isFederated() )
+ fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
+ }
//set thread id for federated context management
if( fedinst != null ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 8d050b3..9ae5014 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -33,6 +33,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
@@ -175,8 +176,8 @@
String ipRegex = "^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$";
if (host.matches("^\\d+\\.\\d+\\.\\d+\\.\\d+$") && !host.matches(ipRegex))
throw new IllegalArgumentException("Input Host address looks like an IP address but is outside range");
- String port = Integer.toString(address.getPort());
- if (port.equals("-1"))
+ int port = address.getPort();
+ if (port == -1)
port = DMLConfig.DEFAULT_FEDERATED_PORT;
String filePath = address.getPath();
if (filePath.length() <= 1)
@@ -193,7 +194,7 @@
if (address.getRef() != null)
throw new IllegalArgumentException("Reference is not supported");
- return new String[] { host, port, filePath };
+ return new String[] { host, String.valueOf(port), filePath };
}
catch (MalformedURLException e) {
throw new IllegalArgumentException("federated address `" + input
@@ -208,6 +209,8 @@
}
List<Pair<FederatedData, Future<FederatedResponse>>> idResponses = new ArrayList<>();
long id = FederationUtils.getNextFedDataID();
+ boolean rowPartitioned = true;
+ boolean colPartitioned = true;
for (Map.Entry<FederatedRange, FederatedData> entry : fedMapping.entrySet()) {
FederatedRange range = entry.getKey();
FederatedData value = entry.getValue();
@@ -215,25 +218,24 @@
long[] beginDims = range.getBeginDims();
long[] endDims = range.getEndDims();
long[] dims = output.getDataCharacteristics().getDims();
- for (int i = 0; i < dims.length; i++) {
+ for (int i = 0; i < dims.length; i++)
dims[i] = endDims[i] - beginDims[i];
- }
- // TODO check if all matrices have the same DataType (currently only double is supported)
idResponses.add(new ImmutablePair<>(value, value.initFederatedData(id)));
}
+ rowPartitioned &= (range.getSize(1) == output.getNumColumns());
+ colPartitioned &= (range.getSize(0) == output.getNumRows());
}
try {
- for (Pair<FederatedData, Future<FederatedResponse>> idResponse : idResponses) {
- FederatedResponse response = idResponse.getRight().get();
- idResponse.getLeft().setVarID((Long) response.getData()[0]);
- }
+ for (Pair<FederatedData, Future<FederatedResponse>> idResponse : idResponses)
+ idResponse.getRight().get(); //wait for initialization
}
catch (Exception e) {
throw new DMLRuntimeException("Federation initialization failed", e);
}
- output.getDataCharacteristics().setNonZeros(output.getNumColumns() * output.getNumRows());
+ output.getDataCharacteristics().setNonZeros(-1);
output.getDataCharacteristics().setBlocksize(ConfigurationManager.getBlocksize());
output.setFedMapping(new FederationMap(id, fedMapping));
+ output.getFedMapping().setType(rowPartitioned ? FType.ROW : colPartitioned ? FType.COL : FType.OTHER);
}
public void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
new file mode 100644
index 0000000..a4b604b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+
+public class ReorgFEDInstruction extends UnaryFEDInstruction {
+
+ public ReorgFEDInstruction(CPOperand in1, CPOperand out, String opcode, String istr) {
+ super(FEDType.Reorg, null, in1, out, opcode, istr);
+ }
+
+ public static ReorgFEDInstruction parseInstruction ( String str ) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+ if ( opcode.equalsIgnoreCase("r'") ) {
+ InstructionUtils.checkNumFields(str, 2, 3);
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand out = new CPOperand(parts[2]);
+ return new ReorgFEDInstruction(in, out, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: "+opcode);
+ }
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+
+ if( !mo1.isFederated() )
+ throw new DMLRuntimeException("Federated Reorg: "
+ + "Federated input expected, but invoked w/ "+mo1.isFederated());
+
+ //execute transpose at federated site
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
+ mo1.getFedMapping().execute(getTID(), true, fr1);
+
+ //derive output federated mapping
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getNumColumns(),
+ mo1.getNumRows(), (int)mo1.getBlocksize(), mo1.getNnz());
+ out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java
index 3a38c13..edcc477 100644
--- a/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java
+++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java
@@ -161,7 +161,7 @@
@Test
public void checkDefaultPortIsValid() {
- int defaultPort = Integer.parseInt(DMLConfig.DEFAULT_FEDERATED_PORT);
+ int defaultPort = DMLConfig.DEFAULT_FEDERATED_PORT;
// The highest port number allowed.
int IANA_limit = 49152;
assertTrue(defaultPort <= IANA_limit);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
index aa88027..8125bfe 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
@@ -124,6 +124,7 @@
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
+ setOutputBuffering(false);
// we need the reference file to not be written to hdfs, so we get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
index 6991797..7da40ab 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -62,7 +62,9 @@
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
{10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
- {10000, 10, 4}, {2000, 50, 4}, {1000, 100, 4}, //concurrent requests
+ {10000, 10, 2}, {2000, 50, 2}, {1000, 100, 2}, //concurrent requests
+ //TODO more runs, e.g., 16 -> but this requires reworking the RPC framework first
+ //(e.g., see paramserv?)
});
}
@@ -127,6 +129,7 @@
Assert.assertTrue(heavyHittersContainsString("fed_+"));
Assert.assertTrue(heavyHittersContainsString("fed_<="));
Assert.assertTrue(heavyHittersContainsString("fed_/"));
+ Assert.assertTrue(heavyHittersContainsString("fed_r'"));
//check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java
index e55cfc9..4cfc70e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java
@@ -103,7 +103,8 @@
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
+ setOutputBuffering(false);
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index 53eac1e..906b124 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -99,10 +99,11 @@
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
+ setOutputBuffering(false);
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-args", input("X1"), input("X2"),
+ programArgs = new String[] {"-stats", "-args", input("X1"), input("X2"),
String.valueOf(scaleAndShift).toUpperCase(), expected("Z")};
runTest(true, false, null, -1);
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTest.dml b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
index 95f136c..13e89ea 100644
--- a/src/test/scripts/functions/federated/FederatedKmeansTest.dml
+++ b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
@@ -21,5 +21,5 @@
X = federated(addresses=list($in_X1, $in_X2),
ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
-[C,Y] = kmeans(X=X, k=4, runs=$runs)
+[C,Y] = kmeans(X=X, k=4, runs=$runs, max_iter=150)
write(C, $out)
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
index da32c8b..e72c9b5 100644
--- a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
X = rbind(read($1), read($2))
-[C,Y] = kmeans(X=X, k=4, runs=$3)
+[C,Y] = kmeans(X=X, k=4, runs=$3, max_iter=150)
write(C, $4)