[SYSTEMDS-2554,2558,2561] Federated transform decode (recoding)
Closes #1027.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index f7e893f..720534a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -768,7 +768,7 @@
new Path(_hdfsFileName), new Path(fName));
//actual export (note: no direct transfer of local copy in order to ensure blocking (and hence, parallelism))
- if( isDirty() || !eqScheme ||
+ if( isDirty() || !eqScheme || isFederated() ||
(pWrite && !isEqualOutputFormat(outputFormat)) )
{
// CASE 1: dirty in-mem matrix or pWrite w/ different format (write matrix to fname; load into memory if evicted)
@@ -781,13 +781,15 @@
{
if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() )
_data = readBlobFromHDFS( _hdfsFileName );
- else
+ else if( getRDDHandle() != null )
_data = readBlobFromRDD( getRDDHandle(), new MutableBoolean() );
+ else
+ _data = readBlobFromFederated( getFedMapping() );
+
setDirty(false);
}
- catch (IOException e)
- {
- throw new DMLRuntimeException("Reading of " + _hdfsFileName + " ("+hashCode()+") failed.", e);
+ catch (IOException e) {
+ throw new DMLRuntimeException("Reading of " + _hdfsFileName + " ("+hashCode()+") failed.", e);
}
}
//get object from cache
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index d161522..9f5f942 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -119,6 +119,7 @@
/**
* Executes an federated operation on a federated worker.
*
+ * @param address socket address (incl host and port)
* @param request the requested operation
* @return the response
*/
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 6dd0abc..f4af303 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -256,7 +256,8 @@
pb.execute(ec); //execute single instruction
}
catch(Exception ex) {
- return new FederatedResponse(ResponseType.ERROR, ex.getMessage());
+ return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
+ "Exception of type " + ex.getClass() + " thrown when processing EXEC_INST request", ex));
}
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
@@ -276,12 +277,19 @@
return udf.execute(ec, inputs);
}
catch(Exception ex) {
- return new FederatedResponse(ResponseType.ERROR, ex.getMessage());
+ return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
+ "Exception of type " + ex.getClass() + " thrown when processing EXEC_UDF request", ex));
}
}
private FederatedResponse execClear() {
- _ecm.clear();
+ try {
+ _ecm.clear();
+ }
+ catch(Exception ex) {
+ return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
+ "Exception of type " + ex.getClass() + " thrown when processing CLEAR request", ex));
+ }
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 5e62475..cfb20e3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -445,7 +445,7 @@
}
else if (opcode.equalsIgnoreCase("transformdecode")) {
CPOperand target = getTargetOperand();
- CPOperand meta = getLiteral(params.get("meta"), DataType.FRAME);
+ CPOperand meta = getLiteral("meta", ValueType.UNKNOWN, DataType.FRAME);
CPOperand spec = getStringLiteral("spec");
return Pair.of(output.getName(), new LineageItem(getOpcode(),
LineageItemUtils.getLineage(ec, target, meta, spec)));
@@ -476,12 +476,12 @@
private CPOperand getBoolLiteral(String name) {
return getLiteral(name, ValueType.BOOLEAN);
}
-
- private CPOperand getLiteral(String name, DataType dt) {
- return new CPOperand(name, ValueType.UNKNOWN, DataType.FRAME);
- }
private CPOperand getLiteral(String name, ValueType vt) {
return new CPOperand(params.get(name), vt, DataType.SCALAR, true);
}
+
+ private CPOperand getLiteral(String name, ValueType vt, DataType dt) {
+ return new CPOperand(params.get(name), vt, dt);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index c8cf729..a1b0a08 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -75,10 +75,13 @@
}
}
else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
- ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction)inst;
+ ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) {
fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
}
+ else if(pinst.getOpcode().equals("transformdecode") && pinst.getTarget(ec).isFederated()) {
+ return ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
+ }
}
else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
MultiReturnParameterizedBuiltinCPInstruction minst = (MultiReturnParameterizedBuiltinCPInstruction) inst;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index ec28965..e3523ed 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,103 +19,213 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
+
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.PrivacyMonitor;
+import org.apache.sysds.runtime.transform.decode.Decoder;
+import org.apache.sysds.runtime.transform.decode.DecoderFactory;
public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction {
-
protected final LinkedHashMap<String, String> params;
-
- protected ParameterizedBuiltinFEDInstruction(Operator op,
- LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr)
- {
+
+ protected ParameterizedBuiltinFEDInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out,
+ String opcode, String istr) {
super(FEDType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
params = paramsMap;
}
-
- public HashMap<String,String> getParameterMap() {
- return params;
+
+ public HashMap<String, String> getParameterMap() {
+ return params;
}
-
+
public String getParam(String key) {
return getParameterMap().get(key);
}
-
+
public static LinkedHashMap<String, String> constructParameterMap(String[] params) {
// process all elements in "params" except first(opcode) and last(output)
- LinkedHashMap<String,String> paramMap = new LinkedHashMap<>();
-
+ LinkedHashMap<String, String> paramMap = new LinkedHashMap<>();
+
// all parameters are of form <name=value>
String[] parts;
- for ( int i=1; i <= params.length-2; i++ ) {
+ for(int i = 1; i <= params.length - 2; i++) {
parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
paramMap.put(parts[0], parts[1]);
}
-
+
return paramMap;
}
-
- public static ParameterizedBuiltinFEDInstruction parseInstruction ( String str ) {
+
+ public static ParameterizedBuiltinFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
// first part is always the opcode
String opcode = parts[0];
// last part is always the output
- CPOperand out = new CPOperand( parts[parts.length-1] );
-
+ CPOperand out = new CPOperand(parts[parts.length - 1]);
+
// process remaining parts and build a hash map
- LinkedHashMap<String,String> paramsMap = constructParameterMap(parts);
-
+ LinkedHashMap<String, String> paramsMap = constructParameterMap(parts);
+
// determine the appropriate value function
- ValueFunction func = null;
if( opcode.equalsIgnoreCase("replace") ) {
- func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
}
+ else if(opcode.equals("transformapply") || opcode.equals("transformdecode")) {
+ return new ParameterizedBuiltinFEDInstruction(null, paramsMap, out, opcode, str);
+ }
else {
throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
}
}
-
- @Override
+
+ @Override
public void processInstruction(ExecutionContext ec) {
String opcode = getOpcode();
- if ( opcode.equalsIgnoreCase("replace") ) {
- //similar to unary federated instructions, get federated input
- //execute instruction, and derive federated output matrix
+ if(opcode.equalsIgnoreCase("replace")) {
+ // similar to unary federated instructions, get federated input
+ // execute instruction, and derive federated output matrix
MatrixObject mo = getTarget(ec);
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{getTargetOperand()}, new long[]{mo.getFedMapping().getID()});
+ new CPOperand[] {getTargetOperand()}, new long[] {mo.getFedMapping().getID()});
mo.getFedMapping().execute(getTID(), true, fr1);
-
- //derive new fed mapping for output
+
+ // derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo.getDataCharacteristics());
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
}
+ else if(opcode.equalsIgnoreCase("transformdecode"))
+ transformDecode(ec);
else {
throw new DMLRuntimeException("Unknown opcode : " + opcode);
}
}
+ private void transformDecode(ExecutionContext ec) {
+ // acquire locks
+ MatrixObject mo = ec.getMatrixObject(params.get("target"));
+ FrameBlock meta = ec.getFrameInput(params.get("meta"));
+ String spec = params.get("spec");
+
+ FederationMap fedMapping = mo.getFedMapping();
+
+ ValueType[] schema = new ValueType[(int) mo.getNumColumns()];
+ long varID = FederationUtils.getNextFedDataID();
+ FederationMap decodedMapping = fedMapping.mapParallel(varID, (range, data) -> {
+ int columnOffset = (int) range.getBeginDims()[1] + 1;
+
+ FrameBlock subMeta = new FrameBlock();
+ synchronized(meta) {
+ meta.slice(0, meta.getNumRows() - 1, columnOffset - 1, (int) range.getEndDims()[1] - 1, subMeta);
+ }
+
+ FederatedResponse response;
+ try {
+ response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+ varID, new DecodeMatrix(data.getVarID(), varID, subMeta, spec, columnOffset))).get();
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+
+ ValueType[] subSchema = (ValueType[]) response.getData()[0];
+ synchronized(schema) {
+ // It would be possible to assert that different federated workers don't give different value
+ // types for the same columns, but the performance impact is not worth the effort
+ System.arraycopy(subSchema, 0, schema, columnOffset - 1, subSchema.length);
+ }
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+ // construct a federated matrix with the encoded data
+ FrameObject decodedFrame = ec.getFrameObject(output);
+ decodedFrame.setSchema(schema);
+ decodedFrame.getDataCharacteristics().set(mo.getDataCharacteristics());
+ // set the federated mapping for the matrix
+ decodedFrame.setFedMapping(decodedMapping);
+
+ // release locks
+ ec.releaseFrameInput(params.get("meta"));
+ }
+
public MatrixObject getTarget(ExecutionContext ec) {
return ec.getMatrixObject(params.get("target"));
}
-
+
private CPOperand getTargetOperand() {
return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
}
+
+ public static class DecodeMatrix extends FederatedUDF {
+ private static final long serialVersionUID = 2376756757742169692L;
+ private final long _outputID;
+ private final FrameBlock _meta;
+ private final String _spec;
+ private final int _globalOffset;
+
+ public DecodeMatrix(long input, long outputID, FrameBlock meta, String spec, int globalOffset) {
+ super(new long[]{input});
+ _outputID = outputID;
+ _meta = meta;
+ _spec = spec;
+ _globalOffset = globalOffset;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixObject mo = (MatrixObject) PrivacyMonitor.handlePrivacy(data[0]);
+ MatrixBlock mb = mo.acquireRead();
+ String[] colNames = _meta.getColumnNames();
+
+ // compute transformdecode
+ Decoder decoder = DecoderFactory.createDecoder(_spec, colNames, null,
+ _meta, mb.getNumColumns(), _globalOffset, _globalOffset + mb.getNumColumns());
+ FrameBlock fbout = decoder.decode(mb, new FrameBlock(decoder.getSchema()));
+ fbout.setColumnNames(Arrays.copyOfRange(colNames, 0, fbout.getNumColumns()));
+
+ // copy characteristics
+ MatrixCharacteristics mc = new MatrixCharacteristics(mo.getDataCharacteristics());
+ FrameObject fo = new FrameObject(OptimizerUtils.getUniqueTempFileName(),
+ new MetaDataFormat(mc, Types.FileFormat.BINARY));
+ // set the encoded data
+ fo.acquireModify(fbout);
+ fo.release();
+ mo.release();
+
+ // add it to the list of variables
+ ec.setVariable(String.valueOf(_outputID), fo);
+ // return schema
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {fo.getSchema()});
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
index 977d494..b51547d 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
@@ -37,24 +37,33 @@
public class DecoderFactory
{
public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta) {
- return createDecoder(spec, colnames, schema, meta, meta.getNumColumns());
+ return createDecoder(spec, colnames, schema, meta, meta.getNumColumns(), -1, -1);
}
- public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, int clen)
+ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, int clen) {
+ return createDecoder(spec, colnames, schema, meta, clen, -1, -1);
+ }
+
+ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, int minCol,
+ int maxCol) {
+ return createDecoder(spec, colnames, schema, meta, meta.getNumColumns(), minCol, maxCol);
+ }
+
+ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema,
+ FrameBlock meta, int clen, int minCol, int maxCol)
{
Decoder decoder = null;
- try
- {
+ try {
//parse transform specification
JSONObject jSpec = new JSONObject(spec);
List<Decoder> ldecoders = new ArrayList<>();
//create decoders 'recode', 'dummy' and 'pass-through'
List<Integer> rcIDs = Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString())));
+ TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol)));
List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString())));
+ TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
rcIDs = unionDistinct(rcIDs, dcIDs);
int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns();
List<Integer> ptIDs = except(UtilFunctions.getSeqList(1, len, 1), rcIDs);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 4c0ad9e..5945e27 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -134,7 +134,7 @@
*/
public Encoder subRangeEncoder(int colStart, int colEnd) {
throw new DMLRuntimeException(
- this.getClass().getName() + " does not support the creation of a sub-range encoder");
+ this.getClass().getSimpleName() + " does not support the creation of a sub-range encoder");
}
/**
@@ -145,7 +145,7 @@
*/
protected void mergeColumnInfo(Encoder other, int col) {
// update number of columns
- _clen = Math.max(_colList.length, col - 1 + other.getNumCols());
+ _clen = Math.max(_clen, col - 1 + other._clen);
// update the new columns that this encoder operates on
Set<Integer> colListAgg = new HashSet<>(); // for dedup
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index dcd2b1c..2070485 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -85,7 +85,7 @@
//create individual encoders
if( !rcIDs.isEmpty() ) {
- EncoderRecode ra = new EncoderRecode(jSpec, colnames, clen);
+ EncoderRecode ra = new EncoderRecode(jSpec, colnames, clen, minCol, maxCol);
ra.setColList(ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])));
lencoders.add(ra);
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index fe3f5b1..d4b201e 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -43,11 +43,11 @@
private HashMap<Integer, HashMap<String, Long>> _rcdMaps = new HashMap<>();
private HashMap<Integer, HashSet<Object>> _rcdMapsPart = null;
- public EncoderRecode(JSONObject parsedSpec, String[] colnames, int clen)
+ public EncoderRecode(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
throws JSONException
{
super(null, clen);
- _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.RECODE.toString());
+ _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol);
}
private EncoderRecode(int[] colList, int clen) {
@@ -58,6 +58,11 @@
this(new int[0], 0);
}
+ private EncoderRecode(int[] colList, int clen, HashMap<Integer, HashMap<String, Long>> rcdMaps) {
+ super(colList, clen);
+ _rcdMaps = rcdMaps;
+ }
+
public HashMap<Integer, HashMap<String,Long>> getCPRecodeMaps() {
return _rcdMaps;
}
@@ -180,9 +185,7 @@
return null;
int[] colList = cols.stream().mapToInt(i -> i).toArray();
- EncoderRecode subRangeEncoder = new EncoderRecode(colList, colEnd - colStart);
- subRangeEncoder._rcdMaps = rcdMaps;
- return subRangeEncoder;
+ return new EncoderRecode(colList, colEnd - colStart, rcdMaps);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 39f5650..72fab7a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -135,7 +135,7 @@
}
if(ix <= 0) {
if (minCol == -1 && maxCol == -1) {
- // only if we cut of some columns, ix -1 is expected
+ // only if we remove some columns, ix -1 is expected
throw new RuntimeException("Specified column '"
+ attrs.get(i)+"' does not exist.");
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java
index a355275..8c60cec 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java
@@ -19,8 +19,6 @@
package org.apache.sysds.test.functions.federated;
-import org.apache.log4j.Level;
-import org.apache.log4j.Logger;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.controlprogram.federated.*;
import org.apache.sysds.test.AutomatedTestBase;
@@ -30,7 +28,6 @@
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import static org.junit.Assert.assertFalse;
@@ -38,23 +35,16 @@
@net.jcip.annotations.NotThreadSafe
public class FederatedNegativeTest {
- protected static Logger log = Logger.getLogger(FederatedNegativeTest.class);
-
- static {
- Logger.getLogger("org.apache.sysds").setLevel(Level.OFF);
- }
-
@Test
public void NegativeTest1() {
- int port = AutomatedTestBase.getRandomAvailablePort();
+ int port = AutomatedTestBase.getRandomAvailablePort();
String[] args = {"-w", Integer.toString(port)};
- Thread t = AutomatedTestBase.startLocalFedWorkerWithArgs(args);
+ Thread t = AutomatedTestBase.startLocalFedWorkerWithArgs(args);
+ FederationUtils.resetFedDataID(); //ensure expected ID when tests run in single JVM
Map<FederatedRange, FederatedData> fedMap = new HashMap<>();
FederatedRange r = new FederatedRange(new long[]{0,0}, new long[]{1,1});
- FederatedData d = new FederatedData(
- Types.DataType.SCALAR,
- new InetSocketAddress("localhost", port),
- "Nowhere");
+ FederatedData d = new FederatedData(Types.DataType.SCALAR,
+ new InetSocketAddress("localhost", port), "Nowhere");
fedMap.put(r,d);
FederationMap fedM = new FederationMap(fedMap);
FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.GET_VAR);
@@ -62,17 +52,11 @@
try {
FederatedResponse fres = res[0].get();
assertFalse(fres.isSuccessful());
- assertTrue(fres.getErrorMessage().contains("Variable 0 does not exist at federated worker"));
-
- } catch (InterruptedException e) {
- e.printStackTrace();
- } catch (ExecutionException e) {
- e.printStackTrace();
- } catch (Exception e) {
+ assertTrue(fres.getErrorMessage().contains("Variable 1 does not exist at federated worker"));
+ }
+ catch (Exception e) {
e.printStackTrace();
}
-
TestUtils.shutdownThread(t);
}
-
}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
index df9def8..1f9c87d 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
@@ -49,7 +49,8 @@
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"FO"}));
+ addTestConfiguration(TEST_NAME1,
+ new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"FO1", "FO2"}));
}
@Test
@@ -126,14 +127,26 @@
"in_AL=" + TestUtils.federatedAddress("localhost", port2, input("AL")),
"in_BU=" + TestUtils.federatedAddress("localhost", port3, input("BU")),
"in_BL=" + TestUtils.federatedAddress("localhost", port4, input("BL")), "rows=" + rows, "cols=" + cols,
- "spec_file=" + SCRIPT_DIR + TEST_DIR + SPEC, "out=" + output("FO"), "format=" + format.toString()};
+ "spec_file=" + SCRIPT_DIR + TEST_DIR + SPEC, "out1=" + output("FO1"), "out2=" + output("FO2"),
+ "format=" + format.toString()};
// run test
runTest(true, false, null, -1);
- // compare matrices (values recoded to identical codes)
+ // compare frame before and after encode and decode
FrameReader reader = FrameReaderFactory.createFrameReader(format);
- FrameBlock FO = reader.readFrameFromHDFS(output("FO"), 15, 2);
+ FrameBlock OUT = reader.readFrameFromHDFS(output("FO2"), rows, cols);
+ for(int r = 0; r < rows; r++) {
+ for(int c = 0; c < cols; c++) {
+ String expected = c < cols / 2 ? Double.toString(A[r][c]) : "Str" + B[r][c - cols / 2];
+ String val = (String) OUT.get(r, c);
+ Assert.assertEquals("Enc- and Decoded frame does not match the source frame: " + expected + " vs "
+ + val, expected, val);
+ }
+ }
+ // TODO federate the aggregated result so that the decode is applied in a federated environment
+ // compare matrices (values recoded to identical codes)
+ FrameBlock FO = reader.readFrameFromHDFS(output("FO1"), 15, 2);
HashMap<String, Long> cFA = getCounts(A, B);
Iterator<String[]> iterFO = FO.getStringRowIterator();
while(iterFO.hasNext()) {
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
index 1ff5446..50174d7 100644
--- a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
@@ -19,19 +19,20 @@
#
#-------------------------------------------------------------
-F1 = federated(type="frame", addresses=list($in_AU, $in_AL, $in_BU, $in_BL), ranges=
+F = federated(type="frame", addresses=list($in_AU, $in_AL, $in_BU, $in_BL), ranges=
list(list(0,0), list($rows / 2, $cols / 2), # AUpper range
list($rows / 2, 0), list($rows, $cols / 2), # ALower range
list(0, $cols / 2), list($rows / 2, $cols), # BUpper range
list($rows / 2, $cols / 2), list($rows, $cols))); # BLower range
jspec = read($spec_file, data_type="scalar", value_type="string");
-[X, M] = transformencode(target=F1, spec=jspec);
+[X, M] = transformencode(target=F, spec=jspec);
A = aggregate(target=X[,1], groups=X[,2], fn="count");
Ag = cbind(A, seq(1,nrow(A)));
-F2 = transformdecode(target=Ag, spec=jspec, meta=M);
+FO1 = transformdecode(target=Ag, spec=jspec, meta=M);
+FO2 = transformdecode(target=X, spec=jspec, meta=M);
-write(F2, $out, format=$format);
-
+write(FO1, $out1, format=$format);
+write(FO2, $out2, format=$format);