[SYSTEMDS-3900] Improved integration of OOC binary stream writer
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index e075b55..7517bf0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -909,44 +909,58 @@
// a) get the matrix
boolean federatedWrite = (outputFormat != null ) && outputFormat.contains("federated");
- if( isEmpty(true) && !federatedWrite)
- {
- //read data from HDFS if required (never read before), this applies only to pWrite w/ different output formats
- //note: for large rdd outputs, we compile dedicated writespinstructions (no need to handle this here)
+ if(getStreamHandle()!=null) {
try {
- if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() )
- _data = readBlobFromHDFS( _hdfsFileName );
- else if( getRDDHandle() != null )
- _data = readBlobFromRDD( getRDDHandle(), new MutableBoolean() );
- else if(!federatedWrite)
- _data = readBlobFromFederated( getFedMapping() );
- setDirty(false);
- refreshMetaData(); //e.g., after unknown csv read
+ long totalNnz = writeStreamToHDFS(fName, outputFormat, replication, formatProperties);
+ updateDataCharacteristics(new MatrixCharacteristics(
+ getNumRows(), getNumColumns(), blen, totalNnz));
+ writeMetaData(fName, outputFormat, formatProperties);
}
- catch (IOException e) {
- throw new DMLRuntimeException("Reading of " + _hdfsFileName + " ("+hashCode()+") failed.", e);
+ catch(Exception ex) {
+ throw new DMLRuntimeException("Failed to write OOC stream to " + fName, ex);
}
}
- //get object from cache
- if(!federatedWrite) {
- if( _data == null )
- getCache();
- acquire( false, _data==null ); //incl. read matrix if evicted
- }
-
- // b) write the matrix
- try {
- writeMetaData( fName, outputFormat, formatProperties );
- writeBlobToHDFS( fName, outputFormat, replication, formatProperties );
- if ( !pWrite )
- setDirty(false);
- }
- catch (Exception e) {
- throw new DMLRuntimeException("Export to " + fName + " failed.", e);
- }
- finally {
- if(!federatedWrite)
- release();
+ else {
+ if( isEmpty(true) && !federatedWrite)
+ {
+ //read data from HDFS if required (never read before), this applies only to pWrite w/ different output formats
+ //note: for large rdd outputs, we compile dedicated writespinstructions (no need to handle this here)
+ try {
+ if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() )
+ _data = readBlobFromHDFS( _hdfsFileName );
+ else if( getRDDHandle() != null )
+ _data = readBlobFromRDD( getRDDHandle(), new MutableBoolean() );
+ else if(!federatedWrite)
+ _data = readBlobFromFederated( getFedMapping() );
+ setDirty(false);
+ refreshMetaData(); //e.g., after unknown csv read
+ }
+ catch (IOException e) {
+ throw new DMLRuntimeException("Reading of " + _hdfsFileName + " ("+hashCode()+") failed.", e);
+ }
+ }
+
+ //get object from cache
+ if(!federatedWrite) {
+ if( _data == null )
+ getCache();
+ acquire( false, _data==null ); //incl. read matrix if evicted
+ }
+
+ // b) write the matrix
+ try {
+ writeMetaData( fName, outputFormat, formatProperties );
+ writeBlobToHDFS( fName, outputFormat, replication, formatProperties );
+ if ( !pWrite )
+ setDirty(false);
+ }
+ catch (Exception e) {
+ throw new DMLRuntimeException("Export to " + fName + " failed.", e);
+ }
+ finally {
+ if(!federatedWrite)
+ release();
+ }
}
}
else if( pWrite ) // pwrite with same output format
@@ -1132,6 +1146,9 @@
protected abstract void writeBlobToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
throws IOException;
+ protected abstract long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
+ throws IOException;
+
protected abstract void writeBlobFromRDDtoHDFS(RDDObject rdd, String fname, String ofmt)
throws IOException;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
index 56cc276..f4d20bb 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
@@ -296,6 +296,14 @@
}
@Override
+ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
+ throws IOException, DMLRuntimeException
+ {
+		throw new UnsupportedOperationException("writeStreamToHDFS is not supported for frames");
+ }
+
+
+ @Override
protected void writeBlobFromRDDtoHDFS(RDDObject rdd, String fname, String ofmt)
throws IOException, DMLRuntimeException
{
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index e9204bd..9f4ca12 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -47,6 +47,8 @@
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.io.ReaderWriterFederated;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageRecomputeUtils;
@@ -601,6 +603,17 @@
if(DMLScript.STATISTICS)
CacheStatistics.incrementHDFSWrites();
}
+
+ @Override
+ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
+ throws IOException, DMLRuntimeException
+ {
+ MetaDataFormat iimd = (MetaDataFormat) _metaData;
+ FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) : iimd.getFileFormat());
+ MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt, rep, fprop);
+ return writer.writeMatrixFromStream(fname, getStreamHandle(),
+ getNumRows(), getNumColumns(), ConfigurationManager.getBlocksize());
+ }
@Override
protected void writeBlobFromRDDtoHDFS(RDDObject rdd, String fname, String outputFormat)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
index d39ed8c..d0111a3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
@@ -189,6 +189,13 @@
if( DMLScript.STATISTICS )
CacheStatistics.incrementHDFSWrites();
}
+
+ @Override
+ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatProperties fprop)
+ throws IOException, DMLRuntimeException
+ {
+		throw new UnsupportedOperationException("writeStreamToHDFS is not supported for tensors");
+ }
@Override
protected ValueType[] getSchema() {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index bd40f25..d3be925 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -41,13 +41,11 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5;
@@ -57,8 +55,6 @@
import org.apache.sysds.runtime.io.WriterHDF5;
import org.apache.sysds.runtime.io.WriterMatrixMarket;
import org.apache.sysds.runtime.io.WriterTextCSV;
-import org.apache.sysds.runtime.io.MatrixWriterFactory;
-import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
@@ -1066,33 +1062,7 @@
else if( getInput1().getDataType() == DataType.MATRIX ) {
MatrixObject mo = ec.getMatrixObject(getInput1().getName());
int blen = Integer.parseInt(getInput4().getName());
- LocalTaskQueue<IndexedMatrixValue> stream = mo.getStreamHandle();
- if (stream != null) {
-
- try {
- MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt);
- long nrows = mo.getNumRows();
- long ncols = mo.getNumColumns();
-
- long totalNnz = writer.writeMatrixFromStream(fname, stream, nrows, ncols, blen);
- MatrixCharacteristics mc = new MatrixCharacteristics(nrows, ncols, blen, totalNnz);
- HDFSTool.writeMetaDataFile(fname + ".mtd", mo.getValueType(), mc, fmt);
-
- // 1. Update the metadata of the MatrixObject in the symbol table.
- mo.updateDataCharacteristics(mc);
- System.out.println("MO characterstics updated to avoid recompilation");
-
- // 2. Clear its dirty flag and update its file path to the result we just wrote.
- // This tells the system that the data for this variable now lives in 'fname'.
- HDFSTool.copyFileOnHDFS(fname, mo.getFileName());
- mo.setDirty(false);
-
- }
- catch(Exception ex) {
- throw new DMLRuntimeException("Failed to write OOC stream to " + fname, ex);
- }
- }
if( fmt == FileFormat.MM )
writeMMFile(ec, fname);
else if( fmt == FileFormat.CSV )
diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
index a3798a4..82c994e 100644
--- a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
@@ -235,16 +235,13 @@
@Override
public long writeMatrixFromStream(String fname, LocalTaskQueue<IndexedMatrixValue> stream, long rlen, long clen, int blen) throws IOException {
- JobConf conf = ConfigurationManager.getCachedJobConf();
Path path = new Path(fname);
- FileSystem fs = IOUtilFunctions.getFileSystem(path, conf);
-
SequenceFile.Writer writer = null;
long totalNnz = 0;
try {
- // 1. Create Sequence file writer for the final destination file writer = new SequenceFile.Writer(fs, conf, path, MatrixIndexes.class, MatrixBlock.class);
- writer = SequenceFile.createWriter(fs, conf, path, MatrixIndexes.class, MatrixBlock.class);
+ // 1. Create Sequence file writer for the final destination file
+			writer = IOUtilFunctions.getSeqWriter(path, ConfigurationManager.getCachedJobConf(), _replication);
// 2. Loop through OOC stream
IndexedMatrixValue i_val = null;
@@ -257,13 +254,12 @@
totalNnz += mb.getNonZeros();
}
-
- } catch (IOException | InterruptedException e) {
+ } catch (Exception e) {
throw new DMLRuntimeException(e);
- } finally {
+ } finally {
IOUtilFunctions.closeSilently(writer);
}
- return totalNnz;
+ return totalNnz;
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java
index fc6f01a..5e203b5 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java
@@ -29,7 +29,6 @@
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
@@ -40,13 +39,10 @@
import org.junit.Test;
import java.io.IOException;
-import java.util.HashMap;
-
-import static org.apache.sysds.test.TestUtils.readDMLMatrixFromHDFS;
public class UnaryTest extends AutomatedTestBase {
- private static final String TEST_NAME = "Unary";
+ private static final String TEST_NAME = "UnaryWrite";
private static final String TEST_DIR = "functions/ooc/";
private static final String TEST_CLASS_DIR = TEST_DIR + UnaryTest.class.getSimpleName() + "/";
private static final String INPUT_NAME = "X";
@@ -55,18 +51,19 @@
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
- addTestConfiguration(TEST_NAME, config);
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
}
- /**
- * Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
- */
@Test
- public void testUnary() {
+ public void testWriteNoRewrite() {
testUnaryOperation(false);
}
+ @Test
+ public void testWriteRewrite() {
+ testUnaryOperation(true);
+ }
+
public void testUnaryOperation(boolean rewrite)
{
@@ -116,8 +113,9 @@
}
}
- private static double[][] readMatrix( String fname, FileFormat fmt, long rows, long cols, int brows, int bcols )
- throws IOException
+ private static double[][] readMatrix( String fname, FileFormat fmt,
+ long rows, long cols, int brows, int bcols )
+ throws IOException
{
MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols);
double[][] C = DataConverter.convertToDoubleMatrix(mb);
diff --git a/src/test/scripts/functions/ooc/Unary.dml b/src/test/scripts/functions/ooc/UnaryWrite.dml
similarity index 91%
rename from src/test/scripts/functions/ooc/Unary.dml
rename to src/test/scripts/functions/ooc/UnaryWrite.dml
index 24c0d98..da2d262 100644
--- a/src/test/scripts/functions/ooc/Unary.dml
+++ b/src/test/scripts/functions/ooc/UnaryWrite.dml
@@ -21,9 +21,5 @@
# Read input matrix and operator from command line args
X = read($1);
-#print(toString(X))
res = ceil(X);
-#print(toString(Y))
-#res = as.matrix(sum(Y));
-# Write the final matrix result
write(res, $2, format="binary");