[SYSTEMDS-2554,2558,2561] Initial federated transform encode
Encoders: recode, pass-through, composite
Closes #966.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 1d5f5df..2c5f902 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -76,6 +76,10 @@
_varID = varID;
}
+ public long getVarID() {
+ return _varID;
+ }
+
public String getFilepath() {
return _filepath;
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index f2d53e4..24be89f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -34,7 +34,9 @@
READ_VAR, // create variable for local data, read on first access
PUT_VAR, // receive data from main and store to local variable
GET_VAR, // return local variable to main
- EXEC_INST // execute arbitrary instruction over
+ EXEC_INST, // execute arbitrary instruction over
+ FRAME_ENCODE, // TODO replace with user defined functions
+ CREATE_ENCODER
}
private RequestType _method;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index b7bbafe..e2332e2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -23,6 +23,7 @@
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
+
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
@@ -30,6 +31,7 @@
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
@@ -44,12 +46,17 @@
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.privacy.PrivacyPropagator;
+import org.apache.sysds.runtime.transform.encode.Encoder;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.utils.JSONHelper;
import org.apache.wink.json4j.JSONObject;
@@ -108,6 +115,10 @@
return getVariable(request);
case EXEC_INST:
return execInstruction(request);
+ case CREATE_ENCODER:
+ return createFrameEncoder(request);
+ case FRAME_ENCODE:
+ return executeFrameEncode(request);
default:
String message = String.format("Method %s is not supported.", method);
return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException(message));
@@ -117,12 +128,64 @@
return new FederatedResponse(FederatedResponse.ResponseType.ERROR, ex);
}
catch (Exception ex) {
- return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
- new FederatedWorkerHandlerException("Exception of type "
+ return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
+ new FederatedWorkerHandlerException("Exception of type "
+ ex.getClass() + " thrown when processing request", ex));
}
}
+ private FederatedResponse createFrameEncoder(FederatedRequest request) {
+ // param parsing
+ checkNumParams(request.getNumParams(), 2);
+ String spec = (String) request.getParam(0);
+ int globalOffset = (int) request.getParam(1);
+ long varID = request.getID();
+
+ Data dataObject = _ec.getVariable(String.valueOf(varID));
+ FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(dataObject);
+ FrameBlock data = fo.acquireRead();
+ String[] colNames = data.getColumnNames();
+
+ // create the encoder
+ Encoder encoder = EncoderFactory.createEncoder(spec, colNames,
+ data.getNumColumns(), null, globalOffset, globalOffset + data.getNumColumns());
+ // build necessary structures for encoding
+ encoder.build(data);
+ // otherwise data of FrameBlock would be null, therefore it would fail
+ // hack because serialization of FrameBlock does not function if Arrays are not allocated
+ fo.release();
+
+ return new FederatedResponse(ResponseType.SUCCESS, encoder);
+ }
+
+ private FederatedResponse executeFrameEncode(FederatedRequest request) {
+ checkNumParams(request.getNumParams(), 2);
+ Encoder encoder = (Encoder) request.getParam(0);
+ long newVarID = (long) request.getParam(1);
+ long varID = request.getID();
+ // look up the input frame stored under the request's variable id
+ Data dataObject = _ec.getVariable(String.valueOf(varID));
+ FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(dataObject);
+ FrameBlock data = fo.acquireRead();
+
+ // apply the received encoder to the local frame; the output matrix has the frame's dimensions
+ MatrixBlock mbout = encoder.apply(data, new MatrixBlock(data.getNumRows(), data.getNumColumns(), false));
+
+ // the encoded matrix reuses a copy of the input frame's data characteristics
+ MatrixCharacteristics mc = new MatrixCharacteristics(fo.getDataCharacteristics());
+ MatrixObject mo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName(),
+ new MetaDataFormat(mc, FileFormat.BINARY));
+ // set the encoded data
+ mo.acquireModify(mbout);
+ mo.release();
+ fo.release();
+
+ // register the encoded matrix under the new variable id
+ _ec.setVariable(String.valueOf(newVarID), mo);
+ // no payload is returned; the result stays at this worker under newVarID
+ return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
+ }
+
private FederatedResponse readData(FederatedRequest request) {
checkNumParams(request.getNumParams(), 2);
String filename = (String) request.getParam(0);
@@ -143,7 +206,7 @@
break;
default:
// should NEVER happen (if we keep request codes in sync with actual behaviour)
- return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
+ return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
new FederatedWorkerHandlerException("Could not recognize datatype"));
}
@@ -168,6 +231,9 @@
throw new DMLRuntimeException(ex);
}
cd.setMetaData(new MetaDataFormat(mc, fmt));
+ // TODO send FileFormatProperties with request and use them for CSV, this is currently a workaround so reading
+ // of CSV files works
+ cd.setFileFormatProperties(new FileFormatPropertiesCSV());
cd.acquireRead();
cd.refreshMetaData(); //in pinned state
cd.release();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index d2e2300..f224da2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -24,7 +24,10 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
+import java.util.function.BiFunction;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
@@ -35,6 +38,7 @@
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
public class FederationMap
{
@@ -150,4 +154,64 @@
}
return this;
}
+
+ /**
+ * Execute a function for each <code>FederatedRange</code> + <code>FederatedData</code> pair. The function should
+ * not change any data of the pair and instead use <code>mapParallel</code> if that is a necessity. Note that this
+ * operation is parallel and necessary synchronisation has to be performed.
+ *
+ * @param forEachFunction function to execute for each pair
+ */
+ public void forEachParallel(BiFunction<FederatedRange, FederatedData, Void> forEachFunction) {
+ ExecutorService pool = CommonThreadPool.get(_fedMap.size());
+
+ ArrayList<MappingTask> mappingTasks = new ArrayList<>();
+ for(Map.Entry<FederatedRange, FederatedData> fedMap : _fedMap.entrySet())
+ mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID));
+ CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
+ }
+
+ /**
+ * Execute a function for each <code>FederatedRange</code> + <code>FederatedData</code> pair mapping the pairs to
+ * their new form by directly changing both <code>FederatedRange</code> and <code>FederatedData</code>. The varIDs
+ * don't have to be changed by the <code>mappingFunction</code> as that will be done by this method. Note that this
+ * operation is parallel and necessary synchronisation has to be performed.
+ *
+ * @param newVarID the new varID to be used by the new FederationMap
+ * @param mappingFunction the function directly changing ranges and data
+ * @return the new <code>FederationMap</code>
+ */
+ public FederationMap mapParallel(long newVarID, BiFunction<FederatedRange, FederatedData, Void> mappingFunction) {
+ ExecutorService pool = CommonThreadPool.get(_fedMap.size());
+
+ FederationMap fedMapCopy = copyWithNewID(_ID);
+ ArrayList<MappingTask> mappingTasks = new ArrayList<>();
+ for(Map.Entry<FederatedRange, FederatedData> fedMap : fedMapCopy._fedMap.entrySet())
+ mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), mappingFunction, newVarID));
+ CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
+ fedMapCopy._ID = newVarID;
+ return fedMapCopy;
+ }
+
+ private static class MappingTask implements Callable<Void> {
+ private final FederatedRange _range;
+ private final FederatedData _data;
+ private final BiFunction<FederatedRange, FederatedData, Void> _mappingFunction;
+ private final long _varID;
+
+ public MappingTask(FederatedRange range, FederatedData data,
+ BiFunction<FederatedRange, FederatedData, Void> mappingFunction, long varID) {
+ _range = range;
+ _data = data;
+ _mappingFunction = mappingFunction;
+ _varID = varID;
+ }
+
+ @Override
+ public Void call() throws Exception {
+ _mappingFunction.apply(_range, _data);
+ _data.setVarID(_varID);
+ return null;
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
index 97c6e1e..3afb681 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
@@ -93,6 +93,10 @@
public boolean isMatrix() {
return _dataType.isMatrix();
}
+
+ public boolean isFrame() {
+ return _dataType.isFrame();
+ }
public boolean isTensor() {
return _dataType.isTensor();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index fc064eb..f2d0791 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -27,11 +27,12 @@
public abstract class FEDInstruction extends Instruction {
public enum FEDType {
- Init,
AggregateBinary,
AggregateUnary,
Append,
- Binary
+ Binary,
+ Init,
+ MultiReturnParameterizedBuiltin
}
protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index c6927df..d639baa 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
@@ -64,6 +65,17 @@
return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
}
+ else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
+ MultiReturnParameterizedBuiltinCPInstruction instruction = (MultiReturnParameterizedBuiltinCPInstruction) inst;
+ String opcode = instruction.getOpcode();
+ if(opcode.equals("transformencode") && instruction.input1.isFrame()) {
+ CacheableData<?> fo = ec.getCacheableData(instruction.input1);
+ if(fo.isFederated()) {
+ return MultiReturnParameterizedBuiltinFEDInstruction
+ .parseInstruction(instruction.getInstructionString());
+ }
+ }
+ }
return inst;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
new file mode 100644
index 0000000..0aad335
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.concurrent.Future;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.transform.encode.Encoder;
+import org.apache.sysds.runtime.transform.encode.EncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
+import org.apache.sysds.runtime.transform.encode.EncoderRecode;
+
+/** Federated {@code transformencode}: builds encoders at the workers, merges them, and encodes each partition remotely. */
+public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction {
+ protected final ArrayList<CPOperand> _outputs;
+
+ private MultiReturnParameterizedBuiltinFEDInstruction(Operator op, CPOperand input1, CPOperand input2,
+ ArrayList<CPOperand> outputs, String opcode, String istr) {
+ super(FEDType.MultiReturnParameterizedBuiltin, op, input1, input2, null, opcode, istr);
+ _outputs = outputs;
+ }
+
+ public CPOperand getOutput(int i) {
+ return _outputs.get(i);
+ }
+
+ public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(String str) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ ArrayList<CPOperand> outputs = new ArrayList<>();
+ String opcode = parts[0];
+
+ if(opcode.equalsIgnoreCase("transformencode")) {
+ // two inputs (frame, transform spec) and two outputs (encoded matrix, meta data frame)
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ outputs.add(new CPOperand(parts[3], Types.ValueType.FP64, Types.DataType.MATRIX));
+ outputs.add(new CPOperand(parts[4], Types.ValueType.STRING, Types.DataType.FRAME));
+ return new MultiReturnParameterizedBuiltinFEDInstruction(null, in1, in2, outputs, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode);
+ }
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ // obtain the federated input frame and the transform specification
+ FrameObject fin = ec.getFrameObject(input1.getName());
+ String spec = ec.getScalarInput(input2).getStringValue();
+
+ // the encoder in which the complete encoding information will be aggregated
+ EncoderComposite globalEncoder = new EncoderComposite(
+ Arrays.asList(new EncoderRecode(), new EncoderPassThrough()));
+ // first create encoders at the federated workers, then collect them and aggregate them to a single large
+ // encoder
+ FederationMap fedMapping = fin.getFedMapping();
+ fedMapping.forEachParallel((range, data) -> {
+ int columnOffset = (int) range.getBeginDims()[1] + 1;
+
+ // create an encoder with the given spec. The columnOffset (which is 1 based) has to be used to
+ // tell the federated worker how much the indexes in the spec have to be offset.
+ Future<FederatedResponse> response = data.executeFederatedOperation(
+ new FederatedRequest(RequestType.CREATE_ENCODER, data.getVarID(), spec, columnOffset));
+ // collect responses with encoders
+ try {
+ Encoder encoder = (Encoder) response.get().getData()[0];
+ // merge this encoder into a composite encoder
+ synchronized(globalEncoder) {
+ globalEncoder.mergeAt(encoder, columnOffset);
+ }
+ }
+ catch(Exception e) {
+ // keep the original exception as cause so the worker-side stack trace is not lost
+ throw new DMLRuntimeException("Federated encoder creation failed: " + e.getMessage(), e);
+ }
+ return null;
+ });
+ long varID = FederationUtils.getNextFedDataID();
+ FederationMap transformedFedMapping = fedMapping.mapParallel(varID, (range, data) -> {
+ int colStart = (int) range.getBeginDims()[1] + 1;
+ int colEnd = (int) range.getEndDims()[1] + 1;
+ // get the encoder segment that is relevant for this federated worker
+ Encoder encoder = globalEncoder.subRangeEncoder(colStart, colEnd);
+ try {
+ FederatedResponse response = data.executeFederatedOperation(
+ new FederatedRequest(RequestType.FRAME_ENCODE, data.getVarID(), encoder, varID)).get();
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+ // construct a federated matrix with the encoded data
+ MatrixObject transformedMat = ec.getMatrixObject(getOutput(0));
+ transformedMat.getDataCharacteristics().set(fin.getDataCharacteristics());
+ // set the federated mapping for the matrix
+ transformedMat.setFedMapping(transformedFedMapping);
+
+ // set the meta data frame of the aggregated encoder as the second output
+ ec.setFrameOutput(getOutput(1).getName(),
+ globalEncoder.getMetaData(new FrameBlock(globalEncoder.getNumCols(), Types.ValueType.STRING)));
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 7ae6b53..9605380 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -664,6 +664,9 @@
///////
// serialization / deserialization (implementation of writable and externalizable)
+ // FIXME for FrameBlock fix write and readFields, it does not work if the Arrays are not yet
+ // allocated (after fixing, remove the hack noted in FederatedWorkerHandler.createFrameEncoder(FederatedRequest),
+ // i.e. the call to FrameBlock.ensureAllocatedColumns())
@Override
public void write(DataOutput out) throws IOException {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 912325f..4c0ad9e 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -22,8 +22,11 @@
import java.io.Serializable;
import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.wink.json4j.JSONArray;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -123,6 +126,50 @@
public abstract MatrixBlock apply(FrameBlock in, MatrixBlock out);
/**
+ * Returns a new Encoder that only handles a sub range of columns.
+ *
+ * @param colStart the start index of the sub-range (1-based, inclusive)
+ * @param colEnd the end index of the sub-range (1-based, exclusive)
+ * @return an encoder of the same type, just for the sub-range
+ */
+ public Encoder subRangeEncoder(int colStart, int colEnd) {
+ throw new DMLRuntimeException(
+ this.getClass().getName() + " does not support the creation of a sub-range encoder");
+ }
+
+ /**
+ * Merges the column information of another encoder into this one: widens the number of columns this encoder
+ * spans (never shrinking it) and adds the other encoder's columns, shifted by the merge position, to the column list.
+ *
+ * @param other the other encoder of the same type
+ * @param col column at which the second encoder will be merged in (1-based)
+ */
+ protected void mergeColumnInfo(Encoder other, int col) {
+ // keep the larger of the current width and the width needed by the merged-in encoder (merge order is arbitrary)
+ _clen = Math.max(_clen, col - 1 + other.getNumCols());
+
+ // update the new columns that this encoder operates on
+ Set<Integer> colListAgg = new HashSet<>(); // for dedup
+ for(int i : _colList)
+ colListAgg.add(i);
+ for(int i : other._colList)
+ colListAgg.add(col - 1 + i);
+ _colList = colListAgg.stream().mapToInt(i -> i).toArray();
+ }
+
+ /**
+ * Merges another encoder, of a compatible type, in after a certain position. Resizes as necessary.
+ * <code>Encoders</code> are compatible with themselves and <code>EncoderComposite</code> is compatible with every
+ * other <code>Encoder</code>.
+ *
+ * @param other the encoder that should be merged in
+ * @param col the position where it should be placed (1-based)
+ */
+ public void mergeAt(Encoder other, int col) {
+ throw new DMLRuntimeException(
+ this.getClass().getName() + " does not support merging with " + other.getClass().getName());
+ }
+
+ /**
* Construct a frame block out of the transform meta data.
*
* @param out output frame block
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index 08272e0..e653307 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -19,10 +19,12 @@
package org.apache.sysds.runtime.transform.encode;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -39,7 +41,7 @@
private List<Encoder> _encoders = null;
private FrameBlock _meta = null;
- protected EncoderComposite(List<Encoder> encoders) {
+ public EncoderComposite(List<Encoder> encoders) {
super(null, -1);
_encoders = encoders;
}
@@ -100,7 +102,49 @@
}
return out;
}
-
+
+ @Override
+ public Encoder subRangeEncoder(int colStart, int colEnd) {
+ List<Encoder> subRangeEncoders = new ArrayList<>();
+ for (Encoder encoder : _encoders) {
+ Encoder subEncoder = encoder.subRangeEncoder(colStart, colEnd);
+ if (subEncoder != null) {
+ subRangeEncoders.add(subEncoder);
+ }
+ }
+ return new EncoderComposite(subRangeEncoders);
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int col) {
+ if (other instanceof EncoderComposite) {
+ EncoderComposite otherComposite = (EncoderComposite) other;
+ // TODO maybe assert that the _encoders never have the same type of encoder twice or more
+ for (Encoder otherEnc : otherComposite.getEncoders()) {
+ boolean mergedIn = false;
+ for (Encoder encoder : _encoders) {
+ if (encoder.getClass() == otherEnc.getClass()) {
+ encoder.mergeAt(otherEnc, col);
+ mergedIn = true;
+ break;
+ }
+ }
+ if(!mergedIn) {
+ throw new DMLRuntimeException("Tried to merge in encoder of class that is not present in "
+ + "CompositeEncoder: " + otherEnc.getClass().getSimpleName());
+ }
+ }
+ return;
+ }
+ for (Encoder encoder : _encoders) {
+ if (encoder.getClass() == other.getClass()) {
+ encoder.mergeAt(other, col);
+ return;
+ }
+ }
+ super.mergeAt(other, col);
+ }
+
@Override
public FrameBlock getMetaData(FrameBlock out) {
if( _meta != null )
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index b7443f4..dcd2b1c 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -41,6 +41,11 @@
public static Encoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta) {
return createEncoder(spec, colnames, UtilFunctions.nCopies(clen, ValueType.STRING), meta);
}
+
+ public static Encoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, int minCol,
+ int maxCol) {
+ return createEncoder(spec, colnames, UtilFunctions.nCopies(clen, ValueType.STRING), meta, minCol, maxCol);
+ }
public static Encoder createEncoder(String spec, String[] colnames, ValueType[] schema, int clen, FrameBlock meta) {
ValueType[] lschema = (schema==null) ? UtilFunctions.nCopies(clen, ValueType.STRING) : schema;
@@ -48,6 +53,11 @@
}
public static Encoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta) {
+ return createEncoder(spec, colnames, schema, meta, -1, -1);
+ }
+
+ public static Encoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, int minCol,
+ int maxCol) {
Encoder encoder = null;
int clen = schema.length;
@@ -55,21 +65,21 @@
//parse transform specification
JSONObject jSpec = new JSONObject(spec);
List<Encoder> lencoders = new ArrayList<>();
-
+
//prepare basic id lists (recode, feature hash, dummycode, pass-through)
List<Integer> rcIDs = Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString())));
+ TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol)));
List<Integer>haIDs = Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString())));
+ TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol)));
List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString())));
+ TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames);
//note: any dummycode column requires recode as preparation, unless it follows binning
rcIDs = except(unionDistinct(rcIDs, except(dcIDs, binIDs)), haIDs);
List<Integer> ptIDs = except(except(UtilFunctions.getSeqList(1, clen, 1),
unionDistinct(rcIDs,haIDs)), binIDs);
List<Integer> oIDs = Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.OMIT.toString())));
+ TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.OMIT.toString(), minCol, maxCol)));
List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString())));
@@ -86,7 +96,7 @@
}
if( !ptIDs.isEmpty() )
lencoders.add(new EncoderPassThrough(
- ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), clen));
+ ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), clen));
if( !binIDs.isEmpty() )
lencoders.add(new EncoderBin(jSpec, colnames, schema.length));
if( !dcIDs.isEmpty() )
@@ -105,8 +115,8 @@
//initialize meta data w/ robustness for superset of cols
if( meta != null ) {
String[] colnames2 = meta.getColumnNames();
- if( !TfMetaUtils.isIDSpec(jSpec) && colnames!=null && colnames2!=null
- && !ArrayUtils.isEquals(colnames, colnames2) )
+ if( !TfMetaUtils.isIDSpec(jSpec) && colnames!=null && colnames2!=null
+ && !ArrayUtils.isEquals(colnames, colnames2) )
{
HashMap<String, Integer> colPos = getColumnPositions(colnames2);
//create temporary meta frame block w/ shallow column copy
@@ -115,7 +125,7 @@
for( int i=0; i<colnames.length; i++ ) {
if( !colPos.containsKey(colnames[i]) ) {
throw new DMLRuntimeException("Column name not found in meta data: "
- +colnames[i]+" (meta: "+Arrays.toString(colnames2)+")");
+ + colnames[i]+" (meta: "+Arrays.toString(colnames2)+")");
}
int pos = colPos.get(colnames[i]);
meta2.setColumn(i, meta.getColumn(pos));
@@ -129,7 +139,6 @@
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
-
return encoder;
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
index c5eb873..8b3d36a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.transform.encode;
+import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -38,6 +40,10 @@
protected EncoderPassThrough(int[] ptCols, int clen) {
super(ptCols, clen); //1-based
}
+
+ public EncoderPassThrough() {
+ this(new int[0], 0);
+ }
@Override
public MatrixBlock encode(FrameBlock in, MatrixBlock out) {
@@ -64,6 +70,32 @@
return out;
}
+
+ @Override
+ public Encoder subRangeEncoder(int colStart, int colEnd) {
+ if (colStart - 1 >= _clen)
+ return null;
+
+ List<Integer> colList = new ArrayList<>();
+ for (int col : _colList) {
+ if (col >= colStart && col < colEnd)
+ // add the correct column, removed columns before start
+ colList.add(col - (colStart - 1));
+ }
+ if (colList.isEmpty())
+ // empty encoder -> return null
+ return null;
+ return new EncoderPassThrough(colList.stream().mapToInt(i -> i).toArray(), colEnd - colStart);
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int col) {
+ if(other instanceof EncoderPassThrough) {
+ mergeColumnInfo(other, col);
+ return;
+ }
+ super.mergeAt(other, col);
+ }
@Override
public FrameBlock getMetaData(FrameBlock meta) {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index 8a8bbef..fe3f5b1 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -19,9 +19,12 @@
package org.apache.sysds.runtime.transform.encode;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
import java.util.Map.Entry;
import org.apache.wink.json4j.JSONException;
@@ -47,6 +50,14 @@
_colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.RECODE.toString());
}
+ private EncoderRecode(int[] colList, int clen) {
+ super(colList, clen);
+ }
+
+ public EncoderRecode() {
+ this(new int[0], 0);
+ }
+
public HashMap<Integer, HashMap<String,Long>> getCPRecodeMaps() {
return _rcdMaps;
}
@@ -148,6 +159,62 @@
}
@Override
+ public Encoder subRangeEncoder(int colStart, int colEnd) {
+ if (colStart - 1 >= _clen)
+ return null;
+
+ List<Integer> cols = new ArrayList<>();
+ HashMap<Integer, HashMap<String, Long>> rcdMaps = new HashMap<>();
+ for (int col : _colList) {
+ if (col >= colStart && col < colEnd) {
+ // shift the id so it is relative to colStart (columns before colStart are dropped);
+ // subtract colStart - 1 because colStart is 1-based
+ int corrColumn = col - (colStart - 1);
+ cols.add(corrColumn);
+ // copy rcdMap for column
+ rcdMaps.put(corrColumn, new HashMap<>(_rcdMaps.get(col)));
+ }
+ }
+ if (cols.isEmpty())
+ // empty encoder -> sub range encoder does not exist
+ return null;
+
+ int[] colList = cols.stream().mapToInt(i -> i).toArray();
+ EncoderRecode subRangeEncoder = new EncoderRecode(colList, colEnd - colStart);
+ subRangeEncoder._rcdMaps = rcdMaps;
+ return subRangeEncoder;
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int col) {
+ if(other instanceof EncoderRecode) {
+ mergeColumnInfo(other, col);
+
+ // merge together overlapping columns or add new columns
+ EncoderRecode otherRec = (EncoderRecode) other;
+ for (int otherColID : other._colList) {
+ int colID = otherColID + col - 1;
+ //allocate column map if necessary
+ if( !_rcdMaps.containsKey(colID) )
+ _rcdMaps.put(colID, new HashMap<>());
+
+ HashMap<String, Long> otherMap = otherRec._rcdMaps.get(otherColID);
+ if(otherMap != null) {
+ // for each column, add all non present recode values
+ for(Map.Entry<String, Long> entry : otherMap.entrySet()) {
+ if (lookupRCDMap(colID, entry.getKey()) == -1) {
+ // key does not yet exist
+ putCode(_rcdMaps.get(colID), entry.getKey());
+ }
+ }
+ }
+ }
+ return;
+ }
+ super.mergeAt(other, col);
+ }
+
+ @Override
public FrameBlock getMetaData(FrameBlock meta) {
if( !isApplicable() )
return meta;
diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index cb9ad41..39f5650 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -25,18 +25,16 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import org.apache.commons.lang.ArrayUtils;
-import org.apache.wink.json4j.JSONArray;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.api.jmlc.Connection;
-import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -47,6 +45,9 @@
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONArray;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
public class TfMetaUtils
{
@@ -80,17 +81,33 @@
/**
* TODO consolidate external and internal json spec definitions
- *
+ *
* @param spec transform specification as json string
* @param colnames column names
- * @param group ?
+ * @param group attribute name in json class
* @return list of column ids
* @throws JSONException if JSONException occurs
*/
- public static int[] parseJsonIDList(JSONObject spec, String[] colnames, String group)
+ public static int[] parseJsonIDList(JSONObject spec, String[] colnames, String group)
+ throws JSONException
+ {
+ return parseJsonIDList(spec, colnames, group, -1, -1);
+ }
+
+ /**
+ * @param spec transform specification as json string
+ * @param colnames column names
+ * @param group attribute name in json class
+ * @param minCol first column id to keep (1-based, inclusive); smaller ids are ignored, kept ids are shifted by minCol-1 (-1 if not used)
+ * @param maxCol end of the column ids to keep (1-based, exclusive); ids at or above it are ignored (-1 if not used)
+ * @return list of column ids
+ * @throws JSONException if JSONException occurs
+ */
+ public static int[] parseJsonIDList(JSONObject spec, String[] colnames, String group, int minCol, int maxCol)
throws JSONException
{
- int[] colList = new int[0];
+ List<Integer> colList = new ArrayList<>();
+ int[] arr = new int[0];
boolean ids = spec.containsKey("ids") && spec.getBoolean("ids");
if( spec.containsKey(group) ) {
@@ -104,21 +121,35 @@
attrs = (JSONArray)spec.get(group);
//construct ID list array
- colList = new int[attrs.size()];
- for(int i=0; i < colList.length; i++) {
- colList[i] = ids ? UtilFunctions.toInt(attrs.get(i)) :
- (ArrayUtils.indexOf(colnames, attrs.get(i)) + 1);
- if( colList[i] <= 0 ) {
- throw new RuntimeException("Specified column '" +
- attrs.get(i)+"' does not exist.");
+ for(int i=0; i < attrs.length(); i++) {
+ int ix;
+ if (ids) {
+ ix = UtilFunctions.toInt(attrs.get(i));
+ if(maxCol != -1 && ix >= maxCol)
+ ix = -1;
+ if(minCol != -1 && ix >= 0)
+ ix -= minCol - 1;
}
+ else {
+ ix = ArrayUtils.indexOf(colnames, attrs.get(i)) + 1;
+ }
+ if(ix <= 0) {
+ if (minCol == -1 && maxCol == -1) {
+ // a non-positive ix is only expected if some columns were cut off via minCol/maxCol
+ throw new RuntimeException("Specified column '"
+ + attrs.get(i)+"' does not exist.");
+ }
+ else // ignore column
+ continue;
+ }
+ colList.add(ix);
}
-
+
//ensure ascending order of column IDs
- Arrays.sort(colList);
+ arr = colList.stream().mapToInt((i) -> i)
+ .sorted().toArray();
}
-
- return colList;
+ return arr;
}
public static int[] parseJsonObjectIDList(JSONObject spec, String[] colnames, String group)
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
index b4ce148..aa88027 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
@@ -91,7 +91,7 @@
/*
* like other federated functionality, SPARK execution mode is not yet working (waiting for better integration of
* federated instruction building, like propagating information that object is federated)
- *
+ *
* @Test public void federatedFrameConstructionSP() throws IOException {
* federatedFrameConstruction(Types.ExecMode.SPARK); }
*/
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
new file mode 100644
index 0000000..df9def8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.io.FrameReader;
+import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
+ private static final String TEST_NAME1 = "TransformFederatedEncodeDecode";
+ private static final String TEST_DIR = "functions/transform/";
+ private static final String TEST_CLASS_DIR = TEST_DIR+TransformFederatedEncodeDecodeTest.class.getSimpleName()+"/";
+
+ private static final String SPEC = "TransformEncodeDecodeSpec.json";
+
+ private static final int rows = 1234;
+ private static final int cols = 2;
+ private static final double sparsity1 = 0.9;
+ private static final double sparsity2 = 0.1;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"FO"}));
+ }
+
+ @Test
+ public void runTestCSVDenseCP() {
+ runTransformEncodeDecodeTest(false, Types.FileFormat.CSV);
+ }
+
+ @Test
+ public void runTestCSVSparseCP() {
+ runTransformEncodeDecodeTest(true, Types.FileFormat.CSV);
+ }
+
+ @Test
+ public void runTestTextcellDenseCP() {
+ runTransformEncodeDecodeTest(false, Types.FileFormat.TEXT);
+ }
+
+ @Test
+ public void runTestTextcellSparseCP() {
+ runTransformEncodeDecodeTest(true, Types.FileFormat.TEXT);
+ }
+
+ @Test
+ public void runTestBinaryDenseCP() {
+ runTransformEncodeDecodeTest(false, Types.FileFormat.BINARY);
+ }
+
+ @Test
+ public void runTestBinarySparseCP() {
+ runTransformEncodeDecodeTest(true, Types.FileFormat.BINARY);
+ }
+
+ private void runTransformEncodeDecodeTest(boolean sparse, Types.FileFormat format) {
+ ExecMode platformOld = rtplatform;
+ rtplatform = ExecMode.SINGLE_NODE;
+
+ Thread t1 = null, t2 = null, t3 = null, t4 = null;
+ try {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ int port1 = getRandomAvailablePort();
+ t1 = startLocalFedWorker(port1);
+ int port2 = getRandomAvailablePort();
+ t2 = startLocalFedWorker(port2);
+ int port3 = getRandomAvailablePort();
+ t3 = startLocalFedWorker(port3);
+ int port4 = getRandomAvailablePort();
+ t4 = startLocalFedWorker(port4);
+
+ // schema
+ Types.ValueType[] schema = new Types.ValueType[cols / 2];
+ Arrays.fill(schema, Types.ValueType.FP64);
+ // generate and write input data
+ // A is the data that will be aggregated and not recoded
+ double[][] A = TestUtils.round(getRandomMatrix(rows, cols / 2, 1, 15, sparse ? sparsity2 : sparsity1, 7));
+ double[][] AUpper = Arrays.copyOf(A, rows / 2);
+ double[][] ALower = Arrays.copyOfRange(A, rows / 2, rows);
+ writeInputFrameWithMTD("AU", AUpper, false, schema, format);
+ writeInputFrameWithMTD("AL", ALower, false, schema, format);
+
+ // B will be recoded and will be the column that will be grouped by
+ Arrays.fill(schema, Types.ValueType.STRING);
+ // we set sparsity to 1.0 to ensure all the string labels exist
+ double[][] B = TestUtils.round(getRandomMatrix(rows, cols / 2, 1, 15, 1.0, 8));
+ double[][] BUpper = Arrays.copyOf(B, rows / 2);
+ double[][] BLower = Arrays.copyOfRange(B, rows / 2, rows);
+ writeInputFrameWithMTD("BU", BUpper, false, schema, format);
+ writeInputFrameWithMTD("BL", BLower, false, schema, format);
+
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
+
+ programArgs = new String[] {"-nvargs",
+ "in_AU=" + TestUtils.federatedAddress("localhost", port1, input("AU")),
+ "in_AL=" + TestUtils.federatedAddress("localhost", port2, input("AL")),
+ "in_BU=" + TestUtils.federatedAddress("localhost", port3, input("BU")),
+ "in_BL=" + TestUtils.federatedAddress("localhost", port4, input("BL")), "rows=" + rows, "cols=" + cols,
+ "spec_file=" + SCRIPT_DIR + TEST_DIR + SPEC, "out=" + output("FO"), "format=" + format.toString()};
+
+ // run test
+ runTest(true, false, null, -1);
+
+ // compare the decoded per-group counts in FO with the counts computed locally from A and B
+ FrameReader reader = FrameReaderFactory.createFrameReader(format);
+ FrameBlock FO = reader.readFrameFromHDFS(output("FO"), 15, 2);
+ HashMap<String, Long> cFA = getCounts(A, B);
+ Iterator<String[]> iterFO = FO.getStringRowIterator();
+ while(iterFO.hasNext()) {
+ String[] row = iterFO.next();
+ Double expected = (double) cFA.get(row[1]);
+ Double val = (row[0] != null) ? Double.parseDouble(row[0]) : 0;
+ Assert.assertEquals("Output aggregates don't match: " + expected + " vs " + val, expected, val);
+ }
+ }
+ catch(Exception ex) {
+ ex.printStackTrace();
+ Assert.fail(ex.getMessage());
+ }
+ finally {
+ TestUtils.shutdownThread(t1);
+ TestUtils.shutdownThread(t2);
+ TestUtils.shutdownThread(t3);
+ TestUtils.shutdownThread(t4);
+ rtplatform = platformOld;
+ }
+ }
+
+ private static HashMap<String, Long> getCounts(double[][] countFrame, double[][] groupFrame) {
+ HashMap<String, Long> ret = new HashMap<>();
+ for(int i = 0; i < countFrame.length; i++) {
+ String key = "Str" + groupFrame[i][0];
+ Long tmp = ret.get(key);
+ ret.put(key, (tmp != null) ? tmp + 1 : 1);
+ }
+ return ret;
+ }
+}
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
new file mode 100644
index 0000000..1ff5446
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+F1 = federated(type="frame", addresses=list($in_AU, $in_AL, $in_BU, $in_BL), ranges=
+ list(list(0,0), list($rows / 2, $cols / 2), # AUpper range
+ list($rows / 2, 0), list($rows, $cols / 2), # ALower range
+ list(0, $cols / 2), list($rows / 2, $cols), # BUpper range
+ list($rows / 2, $cols / 2), list($rows, $cols))); # BLower range
+jspec = read($spec_file, data_type="scalar", value_type="string");
+
+[X, M] = transformencode(target=F1, spec=jspec);
+
+A = aggregate(target=X[,1], groups=X[,2], fn="count");
+Ag = cbind(A, seq(1,nrow(A)));
+
+F2 = transformdecode(target=Ag, spec=jspec, meta=M);
+
+write(F2, $out, format=$format);
+