[SYSTEMDS-2704] Add federated read 1 worker test
This commit adds a tests for one federated worker case, since this was
not tested before. Also a test case for Federated Y L2SVM is added for
a different number of workers.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 267ff9f..06de8f7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -364,7 +364,8 @@
if(_fedMapping == null && _metaData instanceof MetaDataFormat){
MetaDataFormat mdf = (MetaDataFormat) _metaData;
if(mdf.getFileFormat() == FileFormat.FEDERATED){
- InitFEDInstruction.federateMatrix(this, ReaderWriterFederated.read(_hdfsFileName, mdf.getDataCharacteristics()));
+ InitFEDInstruction.federateMatrix(
+ this, ReaderWriterFederated.read(_hdfsFileName, mdf.getDataCharacteristics()));
return true;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index e932785..10e679c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.controlprogram.federated;
import java.io.BufferedReader;
+import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
@@ -66,11 +67,11 @@
protected static Logger log = Logger.getLogger(FederatedWorkerHandler.class);
private final ExecutionContextMap _ecm;
-
+
public FederatedWorkerHandler(ExecutionContextMap ecm) {
- //Note: federated worker handler created for every command;
- //and concurrent parfor threads at coordinator need separate
- //execution contexts at the federated sites too
+ // Note: federated worker handler created for every command;
+ // and concurrent parfor threads at coordinator need separate
+ // execution contexts at the federated sites too
_ecm = ecm;
}
@@ -80,46 +81,46 @@
}
public FederatedResponse createResponse(Object msg) {
- if( log.isDebugEnabled() ){
+ if(log.isDebugEnabled()) {
log.debug("Received: " + msg.getClass().getSimpleName());
}
- if (!(msg instanceof FederatedRequest[]))
- throw new DMLRuntimeException("FederatedWorkerHandler: Received object no instance of 'FederatedRequest[]'.");
+ if(!(msg instanceof FederatedRequest[]))
+ throw new DMLRuntimeException(
+ "FederatedWorkerHandler: Received object no instance of 'FederatedRequest[]'.");
FederatedRequest[] requests = (FederatedRequest[]) msg;
- FederatedResponse response = null; //last response
-
- for( int i=0; i<requests.length; i++ ) {
+ FederatedResponse response = null; // last response
+
+ for(int i = 0; i < requests.length; i++) {
FederatedRequest request = requests[i];
- if( log.isInfoEnabled() ){
- log.info("Executing command " + (i+1) + "/" + requests.length + ": " + request.getType().name());
- if( log.isDebugEnabled() ){
+ if(log.isInfoEnabled()) {
+ log.info("Executing command " + (i + 1) + "/" + requests.length + ": " + request.getType().name());
+ if(log.isDebugEnabled()) {
log.debug("full command: " + request.toString());
}
}
PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
PrivacyMonitor.clearCheckedConstraints();
-
- //execute command and handle privacy constraints
+
+ // execute command and handle privacy constraints
FederatedResponse tmp = executeCommand(request);
conditionalAddCheckedConstraints(request, tmp);
-
- //select the response for the entire batch of requests
- if (!tmp.isSuccessful()) {
- log.error("Command " + request.getType() + " failed: "
- + tmp.getErrorMessage() + "full command: \n" + request.toString());
- response = (response == null || response.isSuccessful())
- ? tmp : response; //return first error
+
+ // select the response for the entire batch of requests
+ if(!tmp.isSuccessful()) {
+ log.error("Command " + request.getType() + " failed: " + tmp.getErrorMessage() + "full command: \n"
+ + request.toString());
+ response = (response == null || response.isSuccessful()) ? tmp : response; // return first error
}
- else if( request.getType() == RequestType.GET_VAR ) {
- if( response != null && response.isSuccessful() )
+ else if(request.getType() == RequestType.GET_VAR) {
+ if(response != null && response.isSuccessful())
log.error("Multiple GET_VAR are not supported in single batch of requests.");
- response = tmp; //return last get result
+ response = tmp; // return last get result
}
- else if( response == null && i == requests.length-1 ) {
- response = tmp; //return last
+ else if(response == null && i == requests.length - 1) {
+ response = tmp; // return last
}
-
- if (DMLScript.STATISTICS && request.getType() == RequestType.CLEAR && Statistics.allowWorkerStatistics){
+
+ if(DMLScript.STATISTICS && request.getType() == RequestType.CLEAR && Statistics.allowWorkerStatistics) {
System.out.println("Federated Worker " + Statistics.display());
Statistics.reset();
}
@@ -127,17 +128,17 @@
return response;
}
- private static void conditionalAddCheckedConstraints(FederatedRequest request, FederatedResponse response){
- if ( request.checkPrivacy() )
+ private static void conditionalAddCheckedConstraints(FederatedRequest request, FederatedResponse response) {
+ if(request.checkPrivacy())
response.setCheckedConstraints(PrivacyMonitor.getCheckedConstraints());
}
private FederatedResponse executeCommand(FederatedRequest request) {
RequestType method = request.getType();
try {
- switch (method) {
+ switch(method) {
case READ_VAR:
- return readData(request); //matrix/frame
+ return readData(request); // matrix/frame
case PUT_VAR:
return putVariable(request);
case GET_VAR:
@@ -150,24 +151,22 @@
return execClear();
default:
String message = String.format("Method %s is not supported.", method);
- return new FederatedResponse(ResponseType.ERROR,
- new FederatedWorkerHandlerException(message));
+ return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(message));
}
}
- catch (DMLPrivacyException | FederatedWorkerHandlerException ex) {
+ catch(DMLPrivacyException | FederatedWorkerHandlerException ex) {
return new FederatedResponse(ResponseType.ERROR, ex);
}
- catch (Exception ex) {
- return new FederatedResponse(ResponseType.ERROR,
- new FederatedWorkerHandlerException("Exception of type "
- + ex.getClass() + " thrown when processing request", ex));
+ catch(Exception ex) {
+ return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
+ "Exception of type " + ex.getClass() + " thrown when processing request", ex));
}
}
-
+
private FederatedResponse readData(FederatedRequest request) {
checkNumParams(request.getNumParams(), 2);
String filename = (String) request.getParam(0);
- DataType dt = DataType.valueOf((String)request.getParam(1));
+ DataType dt = DataType.valueOf((String) request.getParam(1));
return readData(filename, dt, request.getID(), request.getTID());
}
@@ -175,7 +174,7 @@
MatrixCharacteristics mc = new MatrixCharacteristics();
mc.setBlocksize(ConfigurationManager.getBlocksize());
CacheableData<?> cd;
- switch (dataType) {
+ switch(dataType) {
case MATRIX:
cd = new MatrixObject(Types.ValueType.FP64, filename);
break;
@@ -183,93 +182,102 @@
cd = new FrameObject(filename);
break;
default:
- // should NEVER happen (if we keep request codes in sync with actual behaviour)
+ // should NEVER happen (if we keep request codes in sync with actual behavior)
return new FederatedResponse(ResponseType.ERROR,
new FederatedWorkerHandlerException("Could not recognize datatype"));
}
-
- // read metadata
+
FileFormat fmt = null;
boolean header = false;
+ FileSystem fs = null;
try {
String mtdname = DataExpression.getMTDFileName(filename);
Path path = new Path(mtdname);
- FileSystem fs = IOUtilFunctions.getFileSystem(mtdname); //no auto-close
- try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
+ fs = IOUtilFunctions.getFileSystem(mtdname);
+ try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
JSONObject mtd = JSONHelper.parse(br);
- if (mtd == null)
+ if(mtd == null)
return new FederatedResponse(ResponseType.ERROR,
new FederatedWorkerHandlerException("Could not parse metadata file"));
mc.setRows(mtd.getLong(DataExpression.READROWPARAM));
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
if(mtd.containsKey(DataExpression.READNNZPARAM))
mc.setNonZeros(mtd.getLong(DataExpression.READNNZPARAM));
- if (mtd.has(DataExpression.DELIM_HAS_HEADER_ROW))
+ if(mtd.has(DataExpression.DELIM_HAS_HEADER_ROW))
header = mtd.getBoolean(DataExpression.DELIM_HAS_HEADER_ROW);
cd = (CacheableData<?>) PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
fmt = FileFormat.safeValueOf(mtd.getString(DataExpression.FORMAT_TYPE));
}
}
- catch (Exception ex) {
+ catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
-
- //put meta data object in symbol table, read on first operation
+ finally {
+ if(fs != null)
+ try {
+ fs.close();
+ }
+ catch(IOException e) {
+ return new FederatedResponse(ResponseType.ERROR, id);
+ }
+ }
+
+ // put meta data object in symbol table, read on first operation
cd.setMetaData(new MetaDataFormat(mc, fmt));
// TODO send FileFormatProperties with request and use them for CSV, this is currently a workaround so reading
- // of CSV files works
- cd.setFileFormatProperties(new FileFormatPropertiesCSV(header, DataExpression.DEFAULT_DELIM_DELIMITER,
- DataExpression.DEFAULT_DELIM_SPARSE));
- cd.enableCleanup(false); //guard against deletion
+ // of CSV files works
+ if(fmt == FileFormat.CSV)
+ cd.setFileFormatProperties(new FileFormatPropertiesCSV(header, DataExpression.DEFAULT_DELIM_DELIMITER,
+ DataExpression.DEFAULT_DELIM_SPARSE));
+ cd.enableCleanup(false); // guard against deletion
_ecm.get(tid).setVariable(String.valueOf(id), cd);
-
- if (dataType == Types.DataType.FRAME) {
+
+ if(dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
frameObject.acquireRead();
- frameObject.refreshMetaData(); //get block schema
+ frameObject.refreshMetaData(); // get block schema
frameObject.release();
- return new FederatedResponse(ResponseType.SUCCESS,
- new Object[] {id, frameObject.getSchema()});
+ return new FederatedResponse(ResponseType.SUCCESS, new Object[] {id, frameObject.getSchema()});
}
return new FederatedResponse(ResponseType.SUCCESS, id);
}
-
+
private FederatedResponse putVariable(FederatedRequest request) {
checkNumParams(request.getNumParams(), 1);
String varname = String.valueOf(request.getID());
ExecutionContext ec = _ecm.get(request.getTID());
- if( ec.containsVariable(varname) ) {
- return new FederatedResponse(ResponseType.ERROR,
- "Variable "+request.getID()+" already existing.");
+ if(ec.containsVariable(varname)) {
+ return new FederatedResponse(ResponseType.ERROR, "Variable " + request.getID() + " already existing.");
}
-
- //wrap transferred cache block into cacheable data
+
+ // wrap transferred cache block into cacheable data
Data data;
- if( request.getParam(0) instanceof CacheBlock )
+ if(request.getParam(0) instanceof CacheBlock)
data = ExecutionContext.createCacheableData((CacheBlock) request.getParam(0));
- else if( request.getParam(0) instanceof ScalarObject )
+ else if(request.getParam(0) instanceof ScalarObject)
data = (ScalarObject) request.getParam(0);
- else if( request.getParam(0) instanceof ListObject )
+ else if(request.getParam(0) instanceof ListObject)
data = (ListObject) request.getParam(0);
else
- throw new DMLRuntimeException("FederatedWorkerHandler: Unsupported object type, has to be of type CacheBlock or ScalarObject");
-
- //set variable and construct empty response
+ throw new DMLRuntimeException(
+ "FederatedWorkerHandler: Unsupported object type, has to be of type CacheBlock or ScalarObject");
+
+ // set variable and construct empty response
ec.setVariable(varname, data);
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
-
+
private FederatedResponse getVariable(FederatedRequest request) {
checkNumParams(request.getNumParams(), 0);
ExecutionContext ec = _ecm.get(request.getTID());
- if( !ec.containsVariable(String.valueOf(request.getID())) ) {
+ if(!ec.containsVariable(String.valueOf(request.getID()))) {
return new FederatedResponse(ResponseType.ERROR,
- "Variable "+request.getID()+" does not exist at federated worker.");
+ "Variable " + request.getID() + " does not exist at federated worker.");
}
- //get variable and construct response
+ // get variable and construct response
Data dataObject = ec.getVariable(String.valueOf(request.getID()));
dataObject = PrivacyMonitor.handlePrivacy(dataObject);
- switch (dataObject.getDataType()) {
+ switch(dataObject.getDataType()) {
case TENSOR:
case MATRIX:
case FRAME:
@@ -280,20 +288,19 @@
case SCALAR:
return new FederatedResponse(ResponseType.SUCCESS, dataObject);
default:
- return new FederatedResponse(ResponseType.ERROR,
- new FederatedWorkerHandlerException("Unsupported return datatype " + dataObject.getDataType().name()));
+ return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
+ "Unsupported return datatype " + dataObject.getDataType().name()));
}
}
-
+
private FederatedResponse execInstruction(FederatedRequest request) {
ExecutionContext ec = _ecm.get(request.getTID());
BasicProgramBlock pb = new BasicProgramBlock(null);
pb.getInstructions().clear();
- Instruction receivedInstruction = InstructionParser
- .parseSingleInstruction((String)request.getParam(0));
+ Instruction receivedInstruction = InstructionParser.parseSingleInstruction((String) request.getParam(0));
pb.getInstructions().add(receivedInstruction);
try {
- pb.execute(ec); //execute single instruction
+ pb.execute(ec); // execute single instruction
}
catch(Exception ex) {
return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
@@ -301,19 +308,17 @@
}
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
-
+
private FederatedResponse execUDF(FederatedRequest request) {
checkNumParams(request.getNumParams(), 1);
ExecutionContext ec = _ecm.get(request.getTID());
-
- //get function and input parameters
+
+ // get function and input parameters
FederatedUDF udf = (FederatedUDF) request.getParam(0);
- Data[] inputs = Arrays.stream(udf.getInputIDs())
- .mapToObj(id -> ec.getVariable(String.valueOf(id)))
- .map(PrivacyMonitor::handlePrivacy)
- .toArray(Data[]::new);
-
- //execute user-defined function
+ Data[] inputs = Arrays.stream(udf.getInputIDs()).mapToObj(id -> ec.getVariable(String.valueOf(id)))
+ .map(PrivacyMonitor::handlePrivacy).toArray(Data[]::new);
+
+ // execute user-defined function
try {
return udf.execute(ec, inputs);
}
@@ -333,9 +338,9 @@
}
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
-
+
private static void checkNumParams(int actual, int... expected) {
- if (Arrays.stream(expected).anyMatch(x -> x == actual))
+ if(Arrays.stream(expected).anyMatch(x -> x == actual))
return;
throw new DMLRuntimeException("FederatedWorkerHandler: Received wrong amount of params:" + " expected="
+ Arrays.toString(expected) + ", actual=" + actual);
@@ -350,14 +355,10 @@
private static class CloseListener implements ChannelFutureListener {
@Override
public void operationComplete(ChannelFuture channelFuture) throws InterruptedException {
- if (!channelFuture.isSuccess()){
+ if(!channelFuture.isSuccess()) {
log.error("Federated Worker Write failed");
- channelFuture
- .channel()
- .writeAndFlush(
- new FederatedResponse(ResponseType.ERROR,
- new FederatedWorkerHandlerException("Error while sending response.")))
- .channel().close().sync();
+ channelFuture.channel().writeAndFlush(new FederatedResponse(ResponseType.ERROR,
+ new FederatedWorkerHandlerException("Error while sending response."))).channel().close().sync();
}
else {
PrivacyMonitor.clearCheckedConstraints();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index b44aa54..3076b9b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.instructions.fed;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -54,7 +52,7 @@
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
public class FEDInstructionUtils {
- private static final Log LOG = LogFactory.getLog(FEDInstructionUtils.class.getName());
+ // private static final Log LOG = LogFactory.getLog(FEDInstructionUtils.class.getName());
// This is currently a rather simplistic to our solution of replacing instructions with their correct federated
// counterpart, since we do not propagate the information that a matrix is federated, therefore we can not decide
@@ -104,7 +102,6 @@
&& ec.containsVariable(instruction.input1)) {
MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
-
if(instruction.getOpcode().equalsIgnoreCase("cm") && mo1.isFederated()) {
fedinst = CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
} else if(inst.getOpcode().equalsIgnoreCase("qsort") && mo1.isFederated()) {
@@ -153,7 +150,6 @@
MatrixIndexingCPInstruction minst = (MatrixIndexingCPInstruction) inst;
if(inst.getOpcode().equalsIgnoreCase("rightIndex")
&& minst.input1.isMatrix() && ec.getCacheableData(minst.input1).isFederated()) {
- LOG.info("Federated Right Indexing");
fedinst = MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString());
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
index a32fdf2..9b6f74e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
@@ -30,6 +30,7 @@
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -41,6 +42,7 @@
private final static String TEST_DIR = "functions/federated/";
private final static String TEST_NAME = "FederatedYL2SVMTest";
+ private final static String TEST_NAME_2 = "FederatedYL2SVMTest2";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedYL2SVMTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -53,6 +55,7 @@
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
}
@Parameterized.Parameters
@@ -60,21 +63,29 @@
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
// {2, 1000}, {10, 100}, {100, 10}, {1000, 1}, {10, 2000},
- {2000, 10}});
+ {2000, 10}});
}
@Test
public void federatedL2SVMCP() {
- federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, TEST_NAME);
}
- /*
- * TODO support SPARK execution mode -> RDDs and SPARK instructions lead to quite a few problems
- *
- * @Test public void federatedL2SVMSP() { federatedL2SVM(Types.ExecMode.SPARK); }
- */
+ @Test
+ public void federatedL2SVMCP_2() {
+ // This test is equal to the first tests, just with one worker location used instead.
+ // making all federated matrices FULL type.
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, TEST_NAME_2);
- public void federatedL2SVM(Types.ExecMode execMode) {
+ }
+
+ @Test
+ @Ignore
+ public void federatedL2SVMSP() {
+ federatedL2SVM(Types.ExecMode.SPARK, TEST_NAME);
+ }
+
+ public void federatedL2SVM(Types.ExecMode execMode, String testName) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -82,7 +93,7 @@
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- getAndLoadTestConfiguration(TEST_NAME);
+ getAndLoadTestConfiguration(testName);
String HOME = SCRIPT_DIR + TEST_DIR;
// write input matrices
@@ -110,18 +121,17 @@
Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
- TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ TestConfiguration config = availableTestConfigurations.get(testName);
loadTestConfiguration(config);
-
// Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ fullDMLScriptName = HOME + testName + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y1"), input("Y2"), expected("Z")};
LOG.debug(runTest(null));
// Run actual dml script with federated matrixz
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ fullDMLScriptName = HOME + testName + ".dml";
+ programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "out=" + output("Z")};
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index 810b882..62cfd32 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -18,10 +18,11 @@
*/
package org.apache.sysds.test.functions.federated.io;
-
import java.util.Arrays;
import java.util.Collection;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -38,7 +39,7 @@
@net.jcip.annotations.NotThreadSafe
public class FederatedReaderTest extends AutomatedTestBase {
- // private static final Log LOG = LogFactory.getLog(FederatedReaderTest.class.getName());
+ private static final Log LOG = LogFactory.getLog(FederatedReaderTest.class.getName());
private final static String TEST_DIR = "functions/federated/ioR/";
private final static String TEST_NAME = "FederatedReaderTest";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReaderTest.class.getSimpleName() + "/";
@@ -65,15 +66,22 @@
}
@Test
- public void federatedSinglenodeRead() {
- federatedRead(Types.ExecMode.SINGLE_NODE);
+ public void federatedSingleNodeReadOneWorker() {
+ LOG.debug("1Federated");
+ federatedRead(Types.ExecMode.SINGLE_NODE, 1);
}
- public void federatedRead(Types.ExecMode execMode) {
+ @Test
+ public void federatedSingleNodeReadTwoWorker() {
+ LOG.debug("2Federated");
+ federatedRead(Types.ExecMode.SINGLE_NODE, 2);
+ }
+
+ public void federatedRead(Types.ExecMode execMode, int workerCount) {
Types.ExecMode oldPlatform = setExecMode(execMode);
getAndLoadTestConfiguration(TEST_NAME);
setOutputBuffering(true);
-
+
// write input matrices
int halfRows = rows / 2;
long[][] begins = new long[][] {new long[] {0, 0}, new long[] {halfRows, 0}};
@@ -91,18 +99,31 @@
Thread t2 = startLocalFedWorkerThread(port2);
String host = "localhost";
-
try {
- MatrixObject fed = FederatedTestObjectConstructor.constructFederatedInput(
- rows, cols, blocksize, host, begins, ends, new int[] {port1, port2},
- new String[] {input("X1"), input("X2")}, input("X.json"));
+ MatrixObject fed = FederatedTestObjectConstructor.constructFederatedInput(rows,
+ cols,
+ blocksize,
+ host,
+ begins,
+ ends,
+ workerCount == 2 ? new int[] {port1, port2} : new int[] {port1},
+ workerCount == 2 ? new String[] {input("X1"), input("X2")} : new String[] {input("X1")},
+ input("X.json"));
writeInputFederatedWithMTD("X.json", fed, null);
// Run reference dml script with normal matrix
- fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + (rowPartitioned ? "Row" : "Col")
- + "Reference.dml";
- programArgs = new String[] {"-stats", "-args", input("X1"), input("X2")};
+
+ if(workerCount == 1) {
+ fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + "1Reference.dml";
+ programArgs = new String[] {"-stats", "-args", input("X1")};
+ }
+ else {
+ fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME
+ + (rowPartitioned ? "Row" : "Col") + "2Reference.dml";
+ programArgs = new String[] {"-stats", "-args", input("X1"), input("X2")};
+ }
+
String refOut = runTest(null).toString();
-
+
// Run federated
fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-args", input("X.json")};
@@ -111,7 +132,8 @@
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
// Verify output
Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]),
- Double.parseDouble(out.split("\n")[0]), 0.00001);
+ Double.parseDouble(out.split("\n")[0]),
+ 0.00001);
}
catch(Exception e) {
e.printStackTrace();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
index fc2c1dd..273ff0a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
@@ -105,7 +105,7 @@
writeInputFederatedWithMTD("X.json", fed, null);
// Run reference dml script with normal matrix
fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + (rowPartitioned ? "Row" : "Col")
- + "Reference.dml";
+ + "2Reference.dml";
programArgs = new String[] {"-stats", "-args", input("X1"), input("X2")};
String refOut = runTest(null).toString();
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml b/src/test/scripts/functions/federated/FederatedYL2SVMTest2.dml
similarity index 75%
copy from src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
copy to src/test/scripts/functions/federated/FederatedYL2SVMTest2.dml
index 56c2316..4e72b49 100644
--- a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
+++ b/src/test/scripts/functions/federated/FederatedYL2SVMTest2.dml
@@ -19,5 +19,9 @@
#
#-------------------------------------------------------------
-X = cbind(read($1), read($2))
-print(sum(X))
+X = federated(addresses=list($in_X1),
+ ranges=list(list(0, 0), list($rows / 2, $cols)))
+Y = federated(addresses=list($in_Y1),
+ ranges=list(list(0, 0), list($rows / 2, 1)))
+model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, maxIterations = 100)
+write(model, $out)
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml b/src/test/scripts/functions/federated/FederatedYL2SVMTest2Reference.dml
similarity index 86%
copy from src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
copy to src/test/scripts/functions/federated/FederatedYL2SVMTest2Reference.dml
index 56c2316..486e856 100644
--- a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
+++ b/src/test/scripts/functions/federated/FederatedYL2SVMTest2Reference.dml
@@ -19,5 +19,7 @@
#
#-------------------------------------------------------------
-X = cbind(read($1), read($2))
-print(sum(X))
+X = read($1)
+Y = read($3)
+model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, maxIterations = 100)
+write(model, $5)
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml b/src/test/scripts/functions/federated/io/FederatedReaderTest1Reference.dml
similarity index 96%
rename from src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
rename to src/test/scripts/functions/federated/io/FederatedReaderTest1Reference.dml
index 56c2316..0eb8683 100644
--- a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
+++ b/src/test/scripts/functions/federated/io/FederatedReaderTest1Reference.dml
@@ -19,5 +19,5 @@
#
#-------------------------------------------------------------
-X = cbind(read($1), read($2))
+X = read($1)
print(sum(X))
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml b/src/test/scripts/functions/federated/io/FederatedReaderTestCol2Reference.dml
similarity index 95%
copy from src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
copy to src/test/scripts/functions/federated/io/FederatedReaderTestCol2Reference.dml
index 56c2316..54e731b 100644
--- a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
+++ b/src/test/scripts/functions/federated/io/FederatedReaderTestCol2Reference.dml
@@ -19,5 +19,6 @@
#
#-------------------------------------------------------------
-X = cbind(read($1), read($2))
-print(sum(X))
+Y = cbind(read($1), read($2))
+print(sum(Y))
+
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml b/src/test/scripts/functions/federated/io/FederatedReaderTestRow2Reference.dml
similarity index 95%
copy from src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
copy to src/test/scripts/functions/federated/io/FederatedReaderTestRow2Reference.dml
index 56c2316..a2cb4f6 100644
--- a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
+++ b/src/test/scripts/functions/federated/io/FederatedReaderTestRow2Reference.dml
@@ -19,5 +19,6 @@
#
#-------------------------------------------------------------
-X = cbind(read($1), read($2))
-print(sum(X))
+Y = rbind(read($1), read($2))
+print(sum(Y))
+
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestRowReference.dml b/src/test/scripts/functions/federated/io/FederatedReaderTestRowReference.dml
deleted file mode 100644
index 5059e4d..0000000
--- a/src/test/scripts/functions/federated/io/FederatedReaderTestRowReference.dml
+++ /dev/null
@@ -1,23 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-X = rbind(read($1), read($2))
-print(sum(X))