[SYSTEMDS-2704] Federated Read
Exactly what it says in the title. Can read a file if the format is
correct json and the file has a mtd file next to it with the format type:
"federated".
diff --git a/pom.xml b/pom.xml
index 39b6cff..3c3bc07 100644
--- a/pom.xml
+++ b/pom.xml
@@ -979,6 +979,12 @@
</dependency>
<dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>2.11.3</version>
+ </dependency>
+
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 0386fe9..c76cd1c 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -493,6 +493,7 @@
LIBSVM, // text libsvm sparse row representation
JSONL, // text nested JSON (Line) representation
BINARY, // binary block representation (dense/sparse/ultra-sparse)
+ FEDERATED, // A federated matrix
PROTO; // protocol buffer representation
public boolean isIJVFormat() {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index d96b700..da98df3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -40,8 +40,10 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
+import org.apache.sysds.runtime.io.ReaderWriterFederated;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -415,6 +417,7 @@
}
+
@Override
protected MatrixBlock readBlobFromHDFS(String fname, long[] dims)
throws IOException
@@ -430,8 +433,14 @@
+ ", dimensions: [" + mc.getRows() + ", " + mc.getCols() + ", " + mc.getNonZeros() + "]");
begin = System.currentTimeMillis();
}
-
- //read matrix and maintain meta data
+
+ // If the file format is Federated use the federated reader.
+ if(iimd.getFileFormat() == FileFormat.FEDERATED){
+ InitFEDInstruction.federateMatrix(this, ReaderWriterFederated.read(fname,mc));
+ }
+
+ // Read matrix and maintain meta data,
+ // if the MatrixObject is federated there is nothing extra to read, and therefore only acquire read and release
MatrixBlock newData = isFederated() ? acquireReadAndRelease() :
DataConverter.readMatrixFromHDFS(fname, iimd.getFileFormat(), rlen,
clen, mc.getBlocksize(), mc.getNonZeros(), getFileFormatProperties());
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 4a8387f..5a40456 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -19,6 +19,20 @@
package org.apache.sysds.runtime.controlprogram.federated;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.Future;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
@@ -33,23 +47,10 @@
import io.netty.handler.codec.serialization.ObjectEncoder;
import io.netty.util.concurrent.Promise;
-import org.apache.log4j.Logger;
-import org.apache.sysds.common.Types;
-import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
-
-import java.net.InetSocketAddress;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import java.util.concurrent.Future;
-
public class FederatedData {
- protected final static Logger log = Logger.getLogger(FederatedWorkerHandler.class);
- private final static Set<InetSocketAddress> _allFedSites = new HashSet<>();
+ private static final Log LOG = LogFactory.getLog(FederatedData.class.getName());
+ private static final Set<InetSocketAddress> _allFedSites = new HashSet<>();
private final Types.DataType _dataType;
private final InetSocketAddress _address;
@@ -82,6 +83,10 @@
public String getFilepath() {
return _filepath;
}
+
+ public Types.DataType getDataType(){
+ return _dataType;
+ }
public boolean isInitialized() {
return _varID != -1;
@@ -172,7 +177,7 @@
FederationUtils.waitFor(ret);
}
catch(Exception ex) {
- log.warn("Failed to execute CLEAR request on existing federated sites.", ex);
+ LOG.warn("Failed to execute CLEAR request on existing federated sites.", ex);
}
finally {
resetFederatedSites();
@@ -204,4 +209,14 @@
_workerGroup.shutdownGracefully();
}
}
+
+ @Override
+ public String toString(){
+ StringBuilder sb = new StringBuilder();
+ sb.append(this.getClass().getSimpleName().toString());
+ sb.append(" "+ _dataType);
+ sb.append(" "+_address.toString());
+ sb.append(":" + _filepath);
+ return sb.toString();
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 6d2e7c1..04251fc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -86,6 +86,10 @@
public FederatedRange[] getFederatedRanges() {
return _fedMap.keySet().toArray(new FederatedRange[0]);
}
+
+ public Map<FederatedRange, FederatedData> getFedMapping(){
+ return _fedMap;
+ }
public FederatedRequest broadcast(CacheableData<?> data) {
//prepare single request for all federated data
@@ -327,4 +331,13 @@
return null;
}
}
+
+ @Override
+ public String toString(){
+ StringBuilder sb = new StringBuilder();
+ sb.append("Fed Map: " + _type);
+ sb.append("\t ID:" + _ID);
+ sb.append("\n"+ _fedMap);
+ return sb.toString();
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 9ae5014..e42f192 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -19,6 +19,18 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.Future;
+
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
@@ -32,8 +44,8 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
@@ -41,19 +53,9 @@
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.net.MalformedURLException;
-import java.net.URL;
-import java.net.UnknownHostException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.TreeMap;
-import java.util.concurrent.Future;
-
public class InitFEDInstruction extends FEDInstruction {
+
+ // private static final Log LOG = LogFactory.getLog(InitFEDInstruction.class.getName());
public static final String FED_MATRIX_IDENTIFIER = "matrix";
public static final String FED_FRAME_IDENTIFIER = "frame";
@@ -202,7 +204,8 @@
}
}
- public void federateMatrix(MatrixObject output, List<Pair<FederatedRange, FederatedData>> workers) {
+ public static void federateMatrix(MatrixObject output, List<Pair<FederatedRange, FederatedData>> workers) {
+
Map<FederatedRange, FederatedData> fedMapping = new TreeMap<>();
for (Pair<FederatedRange, FederatedData> t : workers) {
fedMapping.put(t.getLeft(), t.getRight());
@@ -238,7 +241,7 @@
output.getFedMapping().setType(rowPartitioned ? FType.ROW : colPartitioned ? FType.COL : FType.OTHER);
}
- public void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers) {
+ public static void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers) {
Map<FederatedRange, FederatedData> fedMapping = new TreeMap<>();
for (Pair<FederatedRange, FederatedData> t : workers) {
fedMapping.put(t.getLeft(), t.getRight());
diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
new file mode 100644
index 0000000..41ae5fd
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.runtime.io;
+
+import static org.junit.Assert.fail;
+
+import java.io.BufferedWriter;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+/**
+ * This class serves as the reader for federated objects. To read the files a mdt file is required. The reader is
+ * different from the other readers in the since that it does not return a MatrixBlock. but a Matrix Object wrapper,
+ * containing the federated Mapping.
+ *
+ * This means in practice that it circumvent the other reading code. See more in:
+ *
+ * org.apache.sysds.runtime.controlprogram.caching.MatrixObject.readBlobFromHDFS()
+ *
+ */
+public class ReaderWriterFederated {
+ private static final Log LOG = LogFactory.getLog(ReaderWriterFederated.class.getName());
+
+ /**
+ * Read a federated map from disk, It is not initialized before it is used in:
+ *
+ * org.apache.sysds.runtime.instructions.fed.InitFEDInstruction
+ *
+ * @param file The file to read (defaults to HDFS)
+ * @param mc The data characteristics of the file, that can be read from the mtd file.
+ * @return A List of federatedRanges and Federated Data
+ */
+ public static List<Pair<FederatedRange, FederatedData>> read(String file, DataCharacteristics mc) {
+ LOG.debug("Reading federated map from " + file);
+ try {
+ JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
+ Path path = new Path(file);
+ FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
+ FSDataInputStream data = fs.open(path);
+ ObjectMapper mapper = new ObjectMapper();
+ List<FederatedDataAddress> obj = mapper.readValue(data, new TypeReference<List<FederatedDataAddress>>() {
+ });
+ return obj.stream().map(x -> x.convert()).collect(Collectors.toList());
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException("Unable to read federated matrix (" + file + ")", e);
+ }
+ }
+
+ /**
+ * TODO add writing to each of the federated locations so that they also save their matrices.
+ *
+ * Currently this would write the federated matrix to disk only locally.
+ *
+ * @param file The file to save to, (defaults to HDFS paths)
+ * @param fedMap The federated map to save.
+ */
+ public static void write(String file, FederationMap fedMap) {
+ if(fedMap.getID() != 0) {
+ // TODO add writing to remote to allow this anyway.
+ throw new DMLRuntimeException(
+ "Invalid to save federated maps with ID's higher than 0, since they are modified.");
+ }
+ LOG.debug("Writing federated map to " + file);
+ try {
+ JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
+ Path path = new Path(file);
+ FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
+ DataOutputStream out = fs.create(path, true);
+ ObjectMapper mapper = new ObjectMapper();
+ // FileOutputStream fileOutputStream = new FileOutputStream("post.json");
+ // String postJson = mapper.writeValueAsString(fedMap);
+ FederatedDataAddress[] outObjects = parseMap(fedMap.getFedMapping());
+ try(BufferedWriter pw = new BufferedWriter(new OutputStreamWriter(out))) {
+ mapper.writeValue(pw, outObjects);
+ }
+ }
+ catch(IOException e) {
+ fail("Unable to write test federated matrix to (" + file + "): " + e.getMessage());
+ }
+ }
+
+ private static FederatedDataAddress[] parseMap(Map<FederatedRange, FederatedData> map) {
+ FederatedDataAddress[] res = new FederatedDataAddress[map.size()];
+ int i = 0;
+ for(Entry<FederatedRange, FederatedData> ent : map.entrySet()) {
+ res[i++] = new FederatedDataAddress(ent.getKey(), ent.getValue());
+ }
+ return res;
+ }
+
+ /**
+ * This class is used for easy serialization from json using Jackson. The warnings are suppressed because the
+ * setters and getters only is used inside Jackson.
+ */
+ @SuppressWarnings("unused")
+ private static class FederatedDataAddress {
+ private Types.DataType _dataType;
+ private InetSocketAddress _address;
+ private String _filepath;
+ private long[] _begin;
+ private long[] _end;
+
+ public FederatedDataAddress() {
+ }
+
+ protected FederatedDataAddress(FederatedRange fr, FederatedData fd) {
+ _dataType = fd.getDataType();
+ _address = fd.getAddress();
+ _filepath = fd.getFilepath();
+ _begin = fr.getBeginDims();
+ _end = fr.getEndDims();
+ }
+
+ protected Pair<FederatedRange, FederatedData> convert() {
+ FederatedRange fr = new FederatedRange(_begin, _end);
+ FederatedData fd = new FederatedData(_dataType, _address, _filepath);
+ return new ImmutablePair<>(fr, fd);
+ }
+
+ public String getFilepath() {
+ return _filepath;
+ }
+
+ public void setFilepath(String filePath) {
+ _filepath = filePath;
+ }
+
+ public Types.DataType getDataType() {
+ return _dataType;
+ }
+
+ public void setDataType(Types.DataType dataType) {
+ _dataType = dataType;
+ }
+
+ public InetSocketAddress getAddress() {
+ return _address;
+ }
+
+ public void setAddress(InetSocketAddress address) {
+ _address = address;
+ }
+
+ public long[] getBegin() {
+ return _begin;
+ }
+
+ public void setBegin(long[] begin) {
+ _begin = begin;
+ }
+
+ public long[] getEnd() {
+ return _end;
+ }
+
+ public void setEnd(long[] end) {
+ _end = end;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(_dataType);
+ sb.append(" ");
+ sb.append(_address);
+ sb.append(" ");
+ sb.append(_filepath);
+ sb.append(" ");
+ sb.append(Arrays.toString(_begin));
+ sb.append(" ");
+ sb.append(Arrays.toString(_end));
+ return sb.toString();
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index b78bf99..b97d907 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -56,10 +56,13 @@
import org.apache.sysds.parser.ParseException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.FrameReader;
import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.io.ReaderWriterFederated;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -200,9 +203,10 @@
private boolean isOutAndExpectedDeletionDisabled = false;
private static boolean outputBuffering = false;
-
+
static {
- java.io.InputStream inputStream = Thread.currentThread().getContextClassLoader().getResourceAsStream("my.properties");
+ java.io.InputStream inputStream = Thread.currentThread().getContextClassLoader()
+ .getResourceAsStream("my.properties");
java.util.Properties properties = new Properties();
try {
properties.load(inputStream);
@@ -338,7 +342,7 @@
protected File getCodegenConfigFile(String parent, CodegenTestType type) {
// Instrumentation in this test's output log to show custom configuration file used for template.
File tmp = new File(parent, type.getCodgenConfig());
- if( LOG.isInfoEnabled() )
+ if(LOG.isInfoEnabled())
LOG.info("This test case overrides default configuration with " + tmp.getPath());
return tmp;
}
@@ -381,6 +385,7 @@
protected double[][] getRandomMatrix(int rows, int cols, double min, double max, double sparsity, long seed) {
return TestUtils.generateTestMatrix(rows, cols, min, max, sparsity, seed);
}
+
/**
* <p>
* Generates a random matrix with the specified characteristics and returns it as a two dimensional array.
@@ -395,9 +400,11 @@
* @param delta The minimum value in between values.
* @return two dimensional array containing random matrix
*/
- protected double[][] getRandomMatrix(int rows, int cols, double min, double max, double sparsity, long seed, double delta) {
+ protected double[][] getRandomMatrix(int rows, int cols, double min, double max, double sparsity, long seed,
+ double delta) {
return TestUtils.generateTestMatrix(rows, cols, min, max, sparsity, seed, delta);
}
+
/**
* <p>
* Generates a random matrix with the specified characteristics which does not contain any zero values and returns
@@ -549,6 +556,34 @@
return matrix;
}
+ protected void writeInputFederatedWithMTD(String name, MatrixObject fm, PrivacyConstraint privacyConstraint){
+ writeFederatedInputMatrix(name, fm.getFedMapping());
+ MatrixCharacteristics mc = (MatrixCharacteristics)fm.getDataCharacteristics();
+ try {
+ String completeMTDPath = baseDirectory + INPUT_DIR + name + ".mtd";
+ HDFSTool.writeMetaDataFile(completeMTDPath, ValueType.FP64, mc, FileFormat.FEDERATED, privacyConstraint);
+ }
+ catch(IOException e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+
+ }
+
+ protected void writeFederatedInputMatrix(String name, FederationMap fedMap){
+ String completePath = baseDirectory + INPUT_DIR + name;
+ try {
+ cleanupExistingData(baseDirectory + INPUT_DIR + name, false);
+ }
+ catch(IOException e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+
+ ReaderWriterFederated.write(completePath, fedMap);
+ inputDirectories.add(baseDirectory + INPUT_DIR + name);
+ }
+
/**
* <p>
* Adds a matrix to the input path and writes it to a file.
@@ -713,7 +748,7 @@
}
public HashMap<CellIndex, Double> readRMatrixFromFS(String fileName) {
- if( LOG.isInfoEnabled() )
+ if(LOG.isInfoEnabled())
LOG.info("R script out: " + baseDirectory + EXPECTED_DIR + cacheDir + fileName);
return TestUtils.readRMatrixFromFS(baseDirectory + EXPECTED_DIR + cacheDir + fileName);
}
@@ -761,7 +796,7 @@
}
public HashMap<CellIndex, Double> readRScalarFromFS(String fileName) {
- if( LOG.isInfoEnabled() )
+ if(LOG.isInfoEnabled())
LOG.info("R script out: " + baseDirectory + EXPECTED_DIR + cacheDir + fileName);
return TestUtils.readRScalarFromFS(baseDirectory + EXPECTED_DIR + cacheDir + fileName);
}
@@ -799,16 +834,18 @@
}
/**
- * Call readDMLMetaDataValue but fail test in case of JSONException or NullPointerException.
- * @param fileName of metadata file
+ * Call readDMLMetaDataValue but fail test in case of JSONException or NullPointerException.
+ *
+ * @param fileName of metadata file
* @param outputDir directory of metadata file
- * @param key key to find in metadata
+ * @param key key to find in metadata
* @return value retrieved from metadata for the given key
*/
- public static String readDMLMetaDataValueCatchException(String fileName, String outputDir, String key){
+ public static String readDMLMetaDataValueCatchException(String fileName, String outputDir, String key) {
try {
return readDMLMetaDataValue(fileName, outputDir, key);
- } catch (JSONException | NullPointerException e){
+ }
+ catch(JSONException | NullPointerException e) {
fail("Privacy constraint not written to output metadata file:\n" + e);
return null;
}
@@ -911,7 +948,7 @@
FileUtils.write(getCurConfigFile(), configContents, "UTF-8");
- if( LOG.isDebugEnabled() )
+ if(LOG.isDebugEnabled())
LOG.debug("This test case will use SystemDS config file %s\n" + getCurConfigFile());
}
catch(IOException e) {
@@ -958,7 +995,8 @@
*/
protected void runRScript(boolean newWay) {
- String executionFile = sourceDirectory + selectedTest + ".R";;
+ String executionFile = sourceDirectory + selectedTest + ".R";
+ ;
if(fullRScriptName != null)
executionFile = fullRScriptName;
@@ -1012,7 +1050,7 @@
if(outputFiles != null && outputFiles.length > 0) {
expectedFile = new File(expectedDir.getPath() + "/" + outputFiles[0]);
if(expectedFile.canRead()) {
- if( LOG.isInfoEnabled() )
+ if(LOG.isInfoEnabled())
LOG.info("Skipping R script cmd: " + cmd);
return;
}
@@ -1023,7 +1061,7 @@
String errorString;
try {
long t0 = System.nanoTime();
- if( LOG.isInfoEnabled() ) {
+ if(LOG.isInfoEnabled()) {
LOG.info("starting R script");
LOG.debug("R cmd: " + cmd);
}
@@ -1031,11 +1069,11 @@
outputR = IOUtils.toString(child.getInputStream());
errorString = IOUtils.toString(child.getErrorStream());
- if( LOG.isTraceEnabled() ) {
+ if(LOG.isTraceEnabled()) {
LOG.trace("Standard Output from R:" + outputR);
LOG.trace("Standard Error from R:" + errorString);
}
-
+
//
// To give any stream enough time to print all data, otherwise there
// are situations where the test case fails, even before everything
@@ -1193,9 +1231,9 @@
String name = "";
final StackTraceElement[] ste = Thread.currentThread().getStackTrace();
- for(int i=0; i < ste.length; i++) {
+ for(int i = 0; i < ste.length; i++) {
if(ste[i].getMethodName().equalsIgnoreCase("invoke0"))
- name = ste[i-1].getClassName() + "." + ste[i-1].getMethodName();
+ name = ste[i - 1].getClassName() + "." + ste[i - 1].getMethodName();
}
LOG.info("Test method name: " + name);
@@ -1242,15 +1280,15 @@
TestUtils.printDMLScript(fullDMLScriptName);
}
}
-
+
ByteArrayOutputStream buff = outputBuffering ? new ByteArrayOutputStream() : null;
PrintStream old = System.out;
if(outputBuffering)
System.setOut(new PrintStream(buff));
-
+
try {
String[] dmlScriptArgs = args.toArray(new String[args.size()]);
- if( LOG.isTraceEnabled() )
+ if(LOG.isTraceEnabled())
LOG.trace("arguments to DMLScript: " + Arrays.toString(dmlScriptArgs));
main(dmlScriptArgs);
@@ -1262,7 +1300,7 @@
fail("expected exception which has not been raised: " + expectedException);
}
catch(Exception | Error e) {
- if( !outputBuffering )
+ if(!outputBuffering)
e.printStackTrace();
if(errMessage != null && !errMessage.equals("")) {
boolean result = rCompareException(exceptionExpected, errMessage, e, false);
@@ -1274,7 +1312,7 @@
StringBuilder errorMessage = new StringBuilder();
errorMessage.append("\nfailed to run script: " + executionFile);
errorMessage.append("\nStandard Out:");
- if( outputBuffering )
+ if(outputBuffering)
errorMessage.append("\n" + buff);
errorMessage.append("\nStackTrace:");
errorMessage.append(getStackTraceString(e, 0));
@@ -1293,9 +1331,7 @@
* @param args command-line arguments
* @throws IOException if an IOException occurs in the hadoop GenericOptionsParser
*/
- public static void main(String[] args)
- throws IOException, ParseException, DMLScriptException
- {
+ public static void main(String[] args) throws IOException, ParseException, DMLScriptException {
Configuration conf = new Configuration(ConfigurationManager.getCachedJobConf());
String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
DMLScript.executeScript(conf, otherArgs);
@@ -1347,27 +1383,25 @@
Process process = null;
String separator = System.getProperty("file.separator");
String classpath = System.getProperty("java.class.path");
- String path = System.getProperty("java.home")
- + separator + "bin" + separator + "java";
- ProcessBuilder processBuilder = new ProcessBuilder(path, "-cp",
- classpath, DMLScript.class.getName(), "-w", Integer.toString(port), "-stats");
+ String path = System.getProperty("java.home") + separator + "bin" + separator + "java";
+ ProcessBuilder processBuilder = new ProcessBuilder(path, "-cp", classpath, DMLScript.class.getName(), "-w",
+ Integer.toString(port), "-stats");
- try{
+ try {
process = processBuilder.start();
// Give some time to startup the worker.
sleep(FED_WORKER_WAIT);
- } catch (IOException | InterruptedException e){
+ }
+ catch(IOException | InterruptedException e) {
e.printStackTrace();
}
return process;
}
/**
- * Start a thread for a worker. This will share the same JVM,
- * so all static variables will be shared.!
+ * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
*
- * Also when using the local Fed Worker thread the statistics printing,
- * and clearing from the worker is disabled.
+ * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
*
* @param port Port to use
* @return the thread associated with the worker.
@@ -1405,12 +1439,13 @@
/**
* Start java worker in same JVM.
+ *
* @param args the command line arguments
* @return the thread associated with the process.s
*/
public static Thread startLocalFedWorkerWithArgs(String[] args) {
Thread t = null;
-
+
try {
t = new Thread(() -> {
try {
@@ -1646,14 +1681,15 @@
public void tearDown() {
LOG.trace("Duration: " + (System.currentTimeMillis() - lTimeBeforeTest) + "ms");
-// assertTrue("expected String did not occur: " + expectedStdOut,
-// iExpectedStdOutState == 0 || iExpectedStdOutState == 2);
-// assertTrue("expected String did not occur (stderr): " + expectedStdErr,
-// iExpectedStdErrState == 0 || iExpectedStdErrState == 2);
-// assertFalse("unexpected String occurred: " + unexpectedStdOut, iUnexpectedStdOutState == 1);
+ // assertTrue("expected String did not occur: " + expectedStdOut,
+ // iExpectedStdOutState == 0 || iExpectedStdOutState == 2);
+ // assertTrue("expected String did not occur (stderr): " + expectedStdErr,
+ // iExpectedStdErrState == 0 || iExpectedStdErrState == 2);
+ // assertFalse("unexpected String occurred: " + unexpectedStdOut, iUnexpectedStdOutState == 1);
TestUtils.displayAssertionBuffer();
if(!isOutAndExpectedDeletionDisabled()) {
+
TestUtils.removeHDFSDirectories(inputDirectories.toArray(new String[inputDirectories.size()]));
TestUtils.removeFiles(inputRFiles.toArray(new String[inputRFiles.size()]));
@@ -1671,7 +1707,7 @@
TestUtils.clearAssertionInformation();
}
- public boolean bufferContainsString(ByteArrayOutputStream buffer, String str){
+ public boolean bufferContainsString(ByteArrayOutputStream buffer, String str) {
return Arrays.stream(buffer.toString().split("\n")).anyMatch(x -> x.contains(str));
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
new file mode 100644
index 0000000..af55b95
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaData;
+import org.junit.Assert;
+
+public class FederatedTestObjectConstructor {
+ public static MatrixObject constructFederatedInput(int rows, int cols, int blocksize, String host, long[][] begin,
+ long[][] end, int[] ports, String[] inputs, String file) {
+ MatrixObject fed = new MatrixObject(ValueType.FP64, file);
+ try {
+ fed.setMetaData(new MetaData(new MatrixCharacteristics(rows, cols, blocksize, rows * cols)));
+ List<Pair<FederatedRange, FederatedData>> d = new ArrayList<>();
+ for(int i = 0; i < ports.length; i++) {
+ FederatedRange X1r = new FederatedRange(begin[i], end[i]);
+ FederatedData X1d = new FederatedData(Types.DataType.MATRIX,
+ new InetSocketAddress(InetAddress.getByName(host), ports[i]), inputs[i]);
+ d.add(new ImmutablePair<>(X1r, X1d));
+ }
+
+ InitFEDInstruction.federateMatrix(fed, d);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ Assert.assertTrue(false);
+ }
+ return fed;
+
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
new file mode 100644
index 0000000..0f2b383
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.federated.io;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.federated.FederatedTestObjectConstructor;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedReaderTest extends AutomatedTestBase {
+
+ // private static final Log LOG = LogFactory.getLog(FederatedReaderTest.class.getName());
+ private final static String TEST_DIR = "functions/federated/io/";
+ private final static String TEST_NAME = "FederatedReaderTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReaderTest.class.getSimpleName() + "/";
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
+ @Parameterized.Parameter(3)
+ public int fedCount;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // number of rows or cols has to be >= number of federated locations.
+ return Arrays.asList(new Object[][] {{10, 13, true, 2},});
+ }
+
+ @Test
+ public void federatedSinglenodeRead() {
+ federatedRead(Types.ExecMode.SINGLE_NODE);
+ }
+
+ public void federatedRead(Types.ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ long[][] begins = new long[][] {new long[] {0, 0}, new long[] {halfRows, 0}};
+ long[][] ends = new long[][] {new long[] {halfRows, cols}, new long[] {rows, cols}};
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
+ String host = "localhost";
+
+ MatrixObject fed = FederatedTestObjectConstructor.constructFederatedInput(rows,
+ cols,
+ blocksize,
+ host,
+ begins,
+ ends,
+ new int[] {port1, port2},
+ new String[] {input("X1"), input("X2")},
+ input("X.json"));
+ writeInputFederatedWithMTD("X.json", fed, null);
+
+ try {
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + (rowPartitioned ? "Row" : "Col") + "Reference.dml";
+ programArgs = new String[] {"-args", input("X1"), input("X2")};
+ String refOut = runTest(null).toString();
+ // Run federated
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-args", input("X.json")};
+ String out = runTest(null).toString();
+ // Verify output
+ Assert.assertEquals(refOut.split("\n")[0], out.split("\n")[0]);
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ Assert.assertTrue(false);
+ }
+
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+}
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTest.dml b/src/test/scripts/functions/federated/io/FederatedReaderTest.dml
new file mode 100644
index 0000000..0eb8683
--- /dev/null
+++ b/src/test/scripts/functions/federated/io/FederatedReaderTest.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1)
+print(sum(X))
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml b/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
new file mode 100644
index 0000000..56c2316
--- /dev/null
+++ b/src/test/scripts/functions/federated/io/FederatedReaderTestColReference.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = cbind(read($1), read($2))
+print(sum(X))
diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestRowReference.dml b/src/test/scripts/functions/federated/io/FederatedReaderTestRowReference.dml
new file mode 100644
index 0000000..5059e4d
--- /dev/null
+++ b/src/test/scripts/functions/federated/io/FederatedReaderTestRowReference.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+print(sum(X))