[SYSTEMDS-362] Federated runtime propagation of privacy constraints
* Runtime propagation of privacy constraints
* Privacy level as enum with three levels: Private, PrivateAggregation,
and None
* Privacy handling in FederatedWorkerHandler preventing private data
from being included in the federated response
* Test of privacy handling of different federated request types
* Test of different privacy levels and combinations for Federated L2SVM
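
Usage sketch (file path and variable name are illustrative): a privacy
level can be attached to an input via the optional "privacy" parameter
of read(); the same level is persisted in and parsed from the "privacy"
field of the corresponding .mtd file:

    X = read("data/X.csv", privacy="Private");

When operands are merged during propagation, the most restrictive level
of the inputs wins (Private over PrivateAggregation over None).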
Closes #919.
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index f0ef363..24aade1 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -73,7 +73,7 @@
protected ValueType _valueType;
protected boolean _visited = false;
protected DataCharacteristics _dc = new MatrixCharacteristics();
- protected PrivacyConstraint _privacyConstraint = new PrivacyConstraint();
+ protected PrivacyConstraint _privacyConstraint = null;
protected UpdateType _updateType = UpdateType.COPY;
protected ArrayList<Hop> _parent = new ArrayList<>();
diff --git a/src/main/java/org/apache/sysds/parser/BinaryExpression.java b/src/main/java/org/apache/sysds/parser/BinaryExpression.java
index 6c177e2..acccb66 100644
--- a/src/main/java/org/apache/sysds/parser/BinaryExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BinaryExpression.java
@@ -146,7 +146,7 @@
}
// Set privacy of output
- output.setPrivacy(PrivacyPropagator.MergeBinary(
+ output.setPrivacy(PrivacyPropagator.mergeBinary(
getLeft().getOutput().getPrivacy(), getRight().getOutput().getPrivacy()));
this.setOutput(output);
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java
index c94532d..779f788 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -37,6 +37,7 @@
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.io.FileFormatPropertiesMM;
import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.JSONHelper;
@@ -1097,10 +1098,8 @@
// set privacy
Expression eprivacy = getVarParam("privacy");
- boolean privacy = false;
- if ( eprivacy != null ) {
- privacy = Boolean.valueOf(eprivacy.toString());
- getOutput().setPrivacy(privacy);
+ if ( eprivacy != null ){
+ getOutput().setPrivacy(PrivacyLevel.valueOf(eprivacy.toString()));
}
// Following dimension checks must be done when data type = MATRIX_DATA_TYPE
@@ -2074,7 +2073,6 @@
if ( key.toString().equalsIgnoreCase(DELIM_HAS_HEADER_ROW)
|| key.toString().equalsIgnoreCase(DELIM_FILL)
|| key.toString().equalsIgnoreCase(DELIM_SPARSE)
- || key.toString().equalsIgnoreCase(PRIVACY)
) {
// parse these parameters as boolean values
BooleanIdentifier boolId = null;
@@ -2096,7 +2094,8 @@
removeVarParam(key.toString());
addVarParam(key.toString(), doubleId);
}
- else if (key.toString().equalsIgnoreCase(DELIM_NA_STRINGS)) {
+ else if (key.toString().equalsIgnoreCase(DELIM_NA_STRINGS)
+ || key.toString().equalsIgnoreCase(PRIVACY)) {
String naStrings = null;
if ( val instanceof String) {
naStrings = val.toString();
diff --git a/src/main/java/org/apache/sysds/parser/Identifier.java b/src/main/java/org/apache/sysds/parser/Identifier.java
index 36d93f2..3ea3252 100644
--- a/src/main/java/org/apache/sysds/parser/Identifier.java
+++ b/src/main/java/org/apache/sysds/parser/Identifier.java
@@ -26,6 +26,7 @@
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
public abstract class Identifier extends Expression
{
@@ -104,8 +105,8 @@
_nnz = nnzs;
}
- public void setPrivacy(boolean privacy){
- _privacy = new PrivacyConstraint(privacy);
+ public void setPrivacy(PrivacyLevel privacyLevel){
+ _privacy = new PrivacyConstraint(privacyLevel);
}
public void setPrivacy(PrivacyConstraint privacyConstraint){
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 279685b..bba731c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -51,6 +51,8 @@
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.PrivacyMonitor;
+import org.apache.sysds.runtime.privacy.PrivacyPropagator;
import org.apache.sysds.utils.JSONHelper;
import org.apache.wink.json4j.JSONObject;
@@ -149,6 +151,7 @@
return new FederatedResponse(FederatedResponse.Type.ERROR, "Could not parse metadata file");
mc.setRows(mtd.getLong(DataExpression.READROWPARAM));
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
+ cd = PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
fmt = FileFormat.safeValueOf(mtd.getString(DataExpression.FORMAT_TYPE));
}
}
@@ -181,6 +184,7 @@
private FederatedResponse executeMatVecMult(long varID, MatrixBlock vector, boolean isMatVecMult) {
MatrixObject matTo = (MatrixObject) _vars.get(varID);
+ matTo = PrivacyMonitor.handlePrivacy(matTo);
MatrixBlock matBlock1 = matTo.acquireReadAndRelease();
// TODO other datatypes
AggregateBinaryOperator ab_op = new AggregateBinaryOperator(
@@ -199,6 +203,7 @@
private FederatedResponse getVariableData(long varID) {
Data dataObject = _vars.get(varID);
+ dataObject = PrivacyMonitor.handlePrivacy(dataObject);
switch (dataObject.getDataType()) {
case TENSOR:
return new FederatedResponse(FederatedResponse.Type.SUCCESS,
@@ -233,6 +238,7 @@
+ dataObject.getDataType().name());
}
MatrixObject matrixObject = (MatrixObject) dataObject;
+ matrixObject = PrivacyMonitor.handlePrivacy(matrixObject);
MatrixBlock matrixBlock = matrixObject.acquireRead();
// create matrix for calculation with correction
MatrixCharacteristics mc = new MatrixCharacteristics();
@@ -270,6 +276,7 @@
private FederatedResponse executeScalarOperation(long varID, ScalarOperator operator) {
Data dataObject = _vars.get(varID);
+ dataObject = PrivacyMonitor.handlePrivacy(dataObject);
if (dataObject.getDataType() != Types.DataType.MATRIX) {
return new FederatedResponse(FederatedResponse.Type.ERROR,
"FederatedWorkerHandler: ScalarOperator dont support "
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index adcae38..db867ef 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -138,6 +138,10 @@
privacyConstraint = lop.getPrivacyConstraint();
}
+ public void setPrivacyConstraint(PrivacyConstraint pc){
+ privacyConstraint = pc;
+ }
+
public PrivacyConstraint getPrivacyConstraint(){
return privacyConstraint;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
index 1fa9d2b..3111042 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
@@ -48,6 +48,14 @@
this.inputs = inputs;
}
+ public CPOperand[] getInputs(){
+ return inputs;
+ }
+
+ public CPOperand getOutput(){
+ return output;
+ }
+
public static BuiltinNaryCPInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index 1e60eea..82aaa7d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -30,6 +30,7 @@
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstructionUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.privacy.PrivacyPropagator;
public abstract class CPInstruction extends Instruction
{
@@ -95,6 +96,8 @@
//robustness federated instructions (runtime assignment)
tmp = FEDInstructionUtils.checkAndReplaceCP(tmp, ec);
+
+ tmp = PrivacyPropagator.preprocessInstruction(tmp, ec);
return tmp;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java
index de30062..7f8ec4d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuaternaryCPInstruction.java
@@ -96,6 +96,9 @@
throw new DMLRuntimeException("Unexpected opcode in QuaternaryCPInstruction: " + inst);
}
+ public CPOperand getInput4() {
+ return input4;
+ }
@Override
public void processInstruction(ExecutionContext ec) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 21b3f637..f40abb0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -58,6 +58,10 @@
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.meta.TensorCharacteristics;
+import org.apache.sysds.runtime.privacy.DMLPrivacyException;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.ProgramConverter;
@@ -289,6 +293,10 @@
return ret;
}
+ public CPOperand getOutput(){
+ return output;
+ }
+
private static int getArity(VariableOperationCode op) {
switch(op) {
case Write:
@@ -512,71 +520,7 @@
switch ( opcode )
{
case CreateVariable:
- //PRE: for robustness we cleanup existing variables, because a setVariable
- //would cause a buffer pool memory leak as these objects would never be removed
- if(ec.containsVariable(getInput1()))
- processRemoveVariableInstruction(ec, getInput1().getName());
-
- if ( getInput1().getDataType() == DataType.MATRIX ) {
- //create new variable for symbol table and cache
- //(existing objects gets cleared through rmvar instructions)
- String fname = getInput2().getName();
- // check if unique filename needs to be generated
- if( Boolean.parseBoolean(getInput3().getName()) )
- fname = getUniqueFileName(fname);
- MatrixObject obj = new MatrixObject(getInput1().getValueType(), fname);
- //clone meta data because it is updated on copy-on-write, otherwise there
- //is potential for hidden side effects between variables.
- obj.setMetaData((MetaData)metadata.clone());
- obj.setPrivacyConstraints(getPrivacyConstraint());
- obj.setFileFormatProperties(_formatProperties);
- obj.setMarkForLinCache(true);
- obj.enableCleanup(!getInput1().getName()
- .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX));
- ec.setVariable(getInput1().getName(), obj);
-
- obj.setUpdateType(_updateType);
- if(DMLScript.STATISTICS && _updateType.isInPlace())
- Statistics.incrementTotalUIPVar();
- }
- else if( getInput1().getDataType() == DataType.TENSOR ) {
- //create new variable for symbol table and cache
- //(existing objects gets cleared through rmvar instructions)
- String fname = getInput2().getName();
- // check if unique filename needs to be generated
- if( Boolean.parseBoolean(getInput3().getName()) )
- fname = getUniqueFileName(fname);
- CacheableData<?> obj = new TensorObject(getInput1().getValueType(), fname);
- //clone meta data because it is updated on copy-on-write, otherwise there
- //is potential for hidden side effects between variables.
- obj.setMetaData((MetaData)metadata.clone());
- obj.setFileFormatProperties(_formatProperties);
- obj.enableCleanup(!getInput1().getName()
- .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX));
- ec.setVariable(getInput1().getName(), obj);
-
- // TODO update
- }
- else if( getInput1().getDataType() == DataType.FRAME ) {
- String fname = getInput2().getName();
- if( Boolean.parseBoolean(getInput3().getName()) )
- fname = getUniqueFileName(fname);
- FrameObject fobj = new FrameObject(fname);
- fobj.setMetaData((MetaData)metadata.clone());
- fobj.setFileFormatProperties(_formatProperties);
- if( _schema != null )
- fobj.setSchema(_schema); //after metadata
- fobj.enableCleanup(!getInput1().getName()
- .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX));
- ec.setVariable(getInput1().getName(), fobj);
- }
- else if ( getInput1().getDataType() == DataType.SCALAR ){
- //created variable not called for scalars
- ec.setScalarOutput(getInput1().getName(), null);
- }
- else {
- throw new DMLRuntimeException("Unexpected data type: " + getInput1().getDataType());
- }
+ processCreateVariableInstruction(ec);
break;
case AssignVariable:
@@ -598,168 +542,38 @@
break;
case RemoveVariableAndFile:
- // Remove the variable from HashMap _variables, and possibly delete the data on disk.
- boolean del = ( (BooleanObject) ec.getScalarInput(getInput2().getName(), getInput2().getValueType(), true) ).getBooleanValue();
- MatrixObject m = (MatrixObject) ec.removeVariable(getInput1().getName());
-
- if ( !del ) {
- // HDFS file should be retailed after clearData(),
- // therefore data must be exported if dirty flag is set
- if ( m.isDirty() )
- m.exportData();
- }
- else {
- //throw new DMLRuntimeException("rmfilevar w/ true is not expected! " + instString);
- //cleanDataOnHDFS(pb, input1.getName());
- cleanDataOnHDFS( m );
- }
-
- // check if in-memory object can be cleaned up
- if ( !ec.getVariables().hasReferences(m) ) {
- // no other variable in the symbol table points to the same Data object as that of input1.getName()
-
- //remove matrix object from cache
- m.clearData();
- }
-
+ processRemoveVariableAndFileInstruction(ec);
break;
case CastAsScalarVariable: //castAsScalarVariable
- if( getInput1().getDataType().isFrame() ) {
- FrameBlock fBlock = ec.getFrameInput(getInput1().getName());
- if( fBlock.getNumRows()!=1 || fBlock.getNumColumns()!=1 )
- throw new DMLRuntimeException("Dimension mismatch - unable to cast frame '"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x "+fBlock.getNumColumns()+") to scalar.");
- Object value = fBlock.get(0,0);
- ec.releaseFrameInput(getInput1().getName());
- ec.setScalarOutput(output.getName(),
- ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value));
- }
- else if( getInput1().getDataType().isMatrix() ) {
- MatrixBlock mBlock = ec.getMatrixInput(getInput1().getName());
- if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 )
- throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix '"+getInput1().getName()+"' of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar.");
- double value = mBlock.getValue(0,0);
- ec.releaseMatrixInput(getInput1().getName());
- ec.setScalarOutput(output.getName(), new DoubleObject(value));
- }
- else if( getInput1().getDataType().isTensor() ) {
- TensorBlock tBlock = ec.getTensorInput(getInput1().getName());
- if (tBlock.getNumDims() != 2 || tBlock.getNumRows() != 1 || tBlock.getNumColumns() != 1)
- throw new DMLRuntimeException("Dimension mismatch - unable to cast tensor '" + getInput1().getName() + "' to scalar.");
- ValueType vt = !tBlock.isBasic() ? tBlock.getSchema()[0] : tBlock.getValueType();
- ec.setScalarOutput(output.getName(), ScalarObjectFactory
- .createScalarObject(vt, tBlock.get(new int[] {0, 0})));
- ec.releaseTensorInput(getInput1().getName());
- }
- else if( getInput1().getDataType().isList() ) {
- //TODO handling of cleanup status, potentially new object
- ListObject list = (ListObject)ec.getVariable(getInput1().getName());
- ec.setVariable(output.getName(), list.slice(0));
- }
- else {
- throw new DMLRuntimeException("Unsupported data type "
- + "in as.scalar(): "+getInput1().getDataType().name());
- }
+ processCastAsScalarVariableInstruction(ec);
break;
- case CastAsMatrixVariable:{
- if( getInput1().getDataType().isFrame() ) {
- FrameBlock fin = ec.getFrameInput(getInput1().getName());
- MatrixBlock out = DataConverter.convertToMatrixBlock(fin);
- ec.releaseFrameInput(getInput1().getName());
- ec.setMatrixOutput(output.getName(), out);
- }
- else if( getInput1().getDataType().isScalar() ) {
- ScalarObject scalarInput = ec.getScalarInput(
- getInput1().getName(), getInput1().getValueType(), getInput1().isLiteral());
- MatrixBlock out = new MatrixBlock(scalarInput.getDoubleValue());
- ec.setMatrixOutput(output.getName(), out);
- }
- else if( getInput1().getDataType().isList() ) {
- //TODO handling of cleanup status, potentially new object
- ListObject list = (ListObject)ec.getVariable(getInput1().getName());
- if( list.getLength() > 1 ) {
- if( !list.checkAllDataTypes(DataType.SCALAR) )
- throw new DMLRuntimeException("as.matrix over multi-entry list only allows scalars.");
- MatrixBlock out = new MatrixBlock(list.getLength(), 1, false);
- for( int i=0; i<list.getLength(); i++ )
- out.quickSetValue(i, 0, ((ScalarObject)list.slice(i)).getDoubleValue());
- ec.setMatrixOutput(output.getName(), out);
- }
- else {
- //pass through matrix input or create 1x1 matrix for scalar
- Data tmp = list.slice(0);
- if( tmp instanceof ScalarObject && tmp.getValueType()!=ValueType.STRING ) {
- MatrixBlock out = new MatrixBlock(((ScalarObject)tmp).getDoubleValue());
- ec.setMatrixOutput(output.getName(), out);
- }
- else {
- ec.setVariable(output.getName(), tmp);
- }
- }
- }
- else {
- throw new DMLRuntimeException("Unsupported data type "
- + "in as.matrix(): "+getInput1().getDataType().name());
- }
+
+ case CastAsMatrixVariable:
+ processCastAsMatrixVariableInstruction(ec);
break;
- }
- case CastAsFrameVariable:{
- FrameBlock out = null;
- if( getInput1().getDataType()==DataType.SCALAR ) {
- ScalarObject scalarInput = ec.getScalarInput(getInput1());
- out = new FrameBlock(1, getInput1().getValueType());
- out.ensureAllocatedColumns(1);
- out.set(0, 0, scalarInput.getStringValue());
- }
- else { //DataType.FRAME
- MatrixBlock min = ec.getMatrixInput(getInput1().getName());
- out = DataConverter.convertToFrameBlock(min);
- ec.releaseMatrixInput(getInput1().getName());
- }
- ec.setFrameOutput(output.getName(), out);
+
+ case CastAsFrameVariable:
+ processCastAsFrameVariableInstruction(ec);
break;
- }
- case CastAsDoubleVariable:{
- ScalarObject in = ec.getScalarInput(getInput1());
- ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToDouble(in));
+
+ case CastAsDoubleVariable:
+ ScalarObject scalarDoubleInput = ec.getScalarInput(getInput1());
+ ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToDouble(scalarDoubleInput));
break;
- }
- case CastAsIntegerVariable:{
- ScalarObject in = ec.getScalarInput(getInput1());
- ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToLong(in));
+
+ case CastAsIntegerVariable:
+ ScalarObject scalarLongInput = ec.getScalarInput(getInput1());
+ ec.setScalarOutput(output.getName(), ScalarObjectFactory.castToLong(scalarLongInput));
break;
- }
- case CastAsBooleanVariable:{
- ScalarObject scalarInput = ec.getScalarInput(getInput1());
- ec.setScalarOutput(output.getName(), new BooleanObject(scalarInput.getBooleanValue()));
+
+ case CastAsBooleanVariable:
+ ScalarObject scalarBooleanInput = ec.getScalarInput(getInput1());
+ ec.setScalarOutput(output.getName(), new BooleanObject(scalarBooleanInput.getBooleanValue()));
break;
- }
case Read:
- ScalarObject res = null;
- try {
- switch(getInput1().getValueType()) {
- case FP64:
- res = new DoubleObject(HDFSTool.readDoubleFromHDFSFile(getInput2().getName()));
- break;
- case INT64:
- res = new IntObject(HDFSTool.readIntegerFromHDFSFile(getInput2().getName()));
- break;
- case BOOLEAN:
- res = new BooleanObject(HDFSTool.readBooleanFromHDFSFile(getInput2().getName()));
- break;
- case STRING:
- res = new StringObject(HDFSTool.readStringFromHDFSFile(getInput2().getName()));
- break;
- default:
- throw new DMLRuntimeException("Invalid value type ("
- + getInput1().getValueType() + ") while processing readScalar instruction.");
- }
- } catch ( IOException e ) {
- throw new DMLRuntimeException(e);
- }
- ec.setScalarOutput(getInput1().getName(), res);
-
+ processReadInstruction(ec);
break;
case Write:
@@ -767,23 +581,82 @@
break;
case SetFileName:
- Data data = ec.getVariable(getInput1().getName());
- if ( data.getDataType() == DataType.MATRIX ) {
- if ( getInput3().getName().equalsIgnoreCase("remote") ) {
- ((MatrixObject)data).setFileName(getInput2().getName());
- }
- else {
- throw new DMLRuntimeException("Invalid location (" + getInput3().getName() + ") in SetFileName instruction: " + instString);
- }
- } else{
- throw new DMLRuntimeException("Invalid data type (" + getInput1().getDataType() + ") in SetFileName instruction: " + instString);
- }
+ processSetFileNameInstruction(ec);
break;
default:
throw new DMLRuntimeException("Unknown opcode: " + opcode );
}
}
+
+ /**
+ * Handler for processInstruction "CreateVariable" case
+ * @param ec execution context of the instruction
+ */
+ private void processCreateVariableInstruction(ExecutionContext ec){
+ //PRE: for robustness we cleanup existing variables, because a setVariable
+ //would cause a buffer pool memory leak as these objects would never be removed
+ if(ec.containsVariable(getInput1()))
+ processRemoveVariableInstruction(ec, getInput1().getName());
+
+ switch(getInput1().getDataType()) {
+ case MATRIX: {
+ String fname = createUniqueFilename();
+ MatrixObject obj = new MatrixObject(getInput1().getValueType(), fname);
+ setCacheableDataFields(obj);
+ obj.setUpdateType(_updateType);
+ obj.setMarkForLinCache(true);
+ ec.setVariable(getInput1().getName(), obj);
+ if(DMLScript.STATISTICS && _updateType.isInPlace())
+ Statistics.incrementTotalUIPVar();
+ break;
+ }
+ case TENSOR: {
+ String fname = createUniqueFilename();
+ TensorObject obj = new TensorObject(getInput1().getValueType(), fname);
+ setCacheableDataFields(obj);
+ ec.setVariable(getInput1().getName(), obj);
+ break;
+ }
+ case FRAME: {
+ String fname = createUniqueFilename();
+ FrameObject fobj = new FrameObject(fname);
+ setCacheableDataFields(fobj);
+ if( _schema != null )
+ fobj.setSchema(_schema); //after metadata
+ ec.setVariable(getInput1().getName(), fobj);
+ break;
+ }
+ case SCALAR: {
+ //created variable not called for scalars
+ ec.setScalarOutput(getInput1().getName(), null);
+ break;
+ }
+ default:
+ throw new DMLRuntimeException("Unexpected data type: " + getInput1().getDataType());
+ }
+ }
+
+ private String createUniqueFilename(){
+ //create new variable for symbol table and cache
+ //(existing objects gets cleared through rmvar instructions)
+ String fname = getInput2().getName();
+ // check if unique filename needs to be generated
+ if( Boolean.parseBoolean(getInput3().getName()) ) {
+ fname = getUniqueFileName(fname);
+ }
+ return fname;
+ }
+
+ private void setCacheableDataFields(CacheableData<?> obj){
+ //clone meta data because it is updated on copy-on-write, otherwise there
+ //is potential for hidden side effects between variables.
+ obj.setMetaData((MetaData)metadata.clone());
+ obj.setPrivacyConstraints(getPrivacyConstraint());
+ obj.enableCleanup(!getInput1().getName()
+ .startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX));
+ obj.setFileFormatProperties(_formatProperties);
+ }
/**
* Handler for mvvar instructions.
@@ -825,7 +698,7 @@
if ( ec.getVariable(getInput1().getName()) == null )
throw new DMLRuntimeException("Unexpected error: could not find a data object for variable name:" + getInput1().getName() + ", while processing instruction " +this.toString());
- Object object = ec.getVariable(getInput1().getName());
+ Data object = ec.getVariable(getInput1().getName());
if ( getInput3().getName().equalsIgnoreCase("binaryblock") ) {
boolean success = false;
@@ -843,6 +716,187 @@
+ ((FrameObject)object).getNumColumns() + "," + ((FrameObject)object).getNumColumns() + "] to " + getInput3().getName());
}
}
+
+ /**
+ * Handler for RemoveVariableAndFile instruction
+ *
+ * @param ec execution context
+ */
+ private void processRemoveVariableAndFileInstruction(ExecutionContext ec){
+ // Remove the variable from HashMap _variables, and possibly delete the data on disk.
+ boolean del = ( (BooleanObject) ec.getScalarInput(getInput2().getName(), getInput2().getValueType(), true) ).getBooleanValue();
+ MatrixObject m = (MatrixObject) ec.removeVariable(getInput1().getName());
+
+ if ( !del ) {
+ // HDFS file should be retained after clearData(),
+ // therefore data must be exported if dirty flag is set
+ if ( m.isDirty() )
+ m.exportData();
+ }
+ else {
+ //throw new DMLRuntimeException("rmfilevar w/ true is not expected! " + instString);
+ //cleanDataOnHDFS(pb, input1.getName());
+ cleanDataOnHDFS( m );
+ }
+
+ // check if in-memory object can be cleaned up
+ if ( !ec.getVariables().hasReferences(m) ) {
+ // no other variable in the symbol table points to the same Data object as that of input1.getName()
+
+ //remove matrix object from cache
+ m.clearData();
+ }
+ }
+
+ /**
+ * Process CastAsScalarVariable instruction.
+ * @param ec execution context
+ */
+ private void processCastAsScalarVariableInstruction(ExecutionContext ec){
+ //TODO: Create privacy constraints for ScalarObject so that the privacy constraints can be propagated to scalars as well.
+ PrivacyMonitor.handlePrivacyScalarOutput(getInput1(), ec);
+
+ switch( getInput1().getDataType() ) {
+ case MATRIX: {
+ MatrixBlock mBlock = ec.getMatrixInput(getInput1().getName());
+ if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 )
+ throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix '"+getInput1().getName()+"' of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar.");
+ double value = mBlock.getValue(0,0);
+ ec.releaseMatrixInput(getInput1().getName());
+ ec.setScalarOutput(output.getName(), new DoubleObject(value));
+ break;
+ }
+ case FRAME: {
+ FrameBlock fBlock = ec.getFrameInput(getInput1().getName());
+ if( fBlock.getNumRows()!=1 || fBlock.getNumColumns()!=1 )
+ throw new DMLRuntimeException("Dimension mismatch - unable to cast frame '"+getInput1().getName()+"' of dimension ("+fBlock.getNumRows()+" x "+fBlock.getNumColumns()+") to scalar.");
+ Object value = fBlock.get(0,0);
+ ec.releaseFrameInput(getInput1().getName());
+ ec.setScalarOutput(output.getName(),
+ ScalarObjectFactory.createScalarObject(fBlock.getSchema()[0], value));
+ break;
+ }
+ case TENSOR: {
+ TensorBlock tBlock = ec.getTensorInput(getInput1().getName());
+ if (tBlock.getNumDims() != 2 || tBlock.getNumRows() != 1 || tBlock.getNumColumns() != 1)
+ throw new DMLRuntimeException("Dimension mismatch - unable to cast tensor '" + getInput1().getName() + "' to scalar.");
+ ValueType vt = !tBlock.isBasic() ? tBlock.getSchema()[0] : tBlock.getValueType();
+ ec.setScalarOutput(output.getName(), ScalarObjectFactory
+ .createScalarObject(vt, tBlock.get(new int[] {0, 0})));
+ ec.releaseTensorInput(getInput1().getName());
+ break;
+ }
+ case LIST: {
+ //TODO handling of cleanup status, potentially new object
+ ListObject list = (ListObject)ec.getVariable(getInput1().getName());
+ ec.setVariable(output.getName(), list.slice(0));
+ break;
+ }
+ default:
+ throw new DMLRuntimeException("Unsupported data type "
+ + "in as.scalar(): "+getInput1().getDataType().name());
+ }
+ }
+
+ /**
+ * Handler for CastAsMatrixVariable instruction
+ * @param ec execution context
+ */
+ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) {
+ switch( getInput1().getDataType() ) {
+ case FRAME: {
+ FrameBlock fin = ec.getFrameInput(getInput1().getName());
+ MatrixBlock out = DataConverter.convertToMatrixBlock(fin);
+ ec.releaseFrameInput(getInput1().getName());
+ ec.setMatrixOutput(output.getName(), out);
+ break;
+ }
+ case SCALAR: {
+ ScalarObject scalarInput = ec.getScalarInput(
+ getInput1().getName(), getInput1().getValueType(), getInput1().isLiteral());
+ MatrixBlock out = new MatrixBlock(scalarInput.getDoubleValue());
+ ec.setMatrixOutput(output.getName(), out);
+ break;
+ }
+ case LIST: {
+ //TODO handling of cleanup status, potentially new object
+ ListObject list = (ListObject)ec.getVariable(getInput1().getName());
+ if( list.getLength() > 1 ) {
+ if( !list.checkAllDataTypes(DataType.SCALAR) )
+ throw new DMLRuntimeException("as.matrix over multi-entry list only allows scalars.");
+ MatrixBlock out = new MatrixBlock(list.getLength(), 1, false);
+ for( int i=0; i<list.getLength(); i++ )
+ out.quickSetValue(i, 0, ((ScalarObject)list.slice(i)).getDoubleValue());
+ ec.setMatrixOutput(output.getName(), out);
+ }
+ else {
+ //pass through matrix input or create 1x1 matrix for scalar
+ Data tmp = list.slice(0);
+ if( tmp instanceof ScalarObject && tmp.getValueType()!=ValueType.STRING ) {
+ MatrixBlock out = new MatrixBlock(((ScalarObject)tmp).getDoubleValue());
+ ec.setMatrixOutput(output.getName(), out);
+ }
+ else {
+ ec.setVariable(output.getName(), tmp);
+ }
+ }
+ break;
+ }
+ default:
+ throw new DMLRuntimeException("Unsupported data type "
+ + "in as.matrix(): "+getInput1().getDataType().name());
+ }
+ }
+
+ /**
+ * Handler for CastAsFrameVariable instruction
+ * @param ec execution context
+ */
+ private void processCastAsFrameVariableInstruction(ExecutionContext ec){
+ FrameBlock out = null;
+ if( getInput1().getDataType()==DataType.SCALAR ) {
+ ScalarObject scalarInput = ec.getScalarInput(getInput1());
+ out = new FrameBlock(1, getInput1().getValueType());
+ out.ensureAllocatedColumns(1);
+ out.set(0, 0, scalarInput.getStringValue());
+ }
+ else { //DataType.FRAME
+ MatrixBlock min = ec.getMatrixInput(getInput1().getName());
+ out = DataConverter.convertToFrameBlock(min);
+ ec.releaseMatrixInput(getInput1().getName());
+ }
+ ec.setFrameOutput(output.getName(), out);
+ }
+
+ /**
+ * Handler for Read instruction
+ * @param ec execution context
+ */
+ private void processReadInstruction(ExecutionContext ec){
+ ScalarObject res = null;
+ try {
+ switch(getInput1().getValueType()) {
+ case FP64:
+ res = new DoubleObject(HDFSTool.readDoubleFromHDFSFile(getInput2().getName()));
+ break;
+ case INT64:
+ res = new IntObject(HDFSTool.readIntegerFromHDFSFile(getInput2().getName()));
+ break;
+ case BOOLEAN:
+ res = new BooleanObject(HDFSTool.readBooleanFromHDFSFile(getInput2().getName()));
+ break;
+ case STRING:
+ res = new StringObject(HDFSTool.readStringFromHDFSFile(getInput2().getName()));
+ break;
+ default:
+ throw new DMLRuntimeException("Invalid value type ("
+ + getInput1().getValueType() + ") while processing readScalar instruction.");
+ }
+ } catch ( IOException e ) {
+ throw new DMLRuntimeException(e);
+ }
+ ec.setScalarOutput(getInput1().getName(), res);
+ }
/**
* Handler for cpvar instructions.
@@ -898,20 +952,38 @@
else {
// Default behavior
MatrixObject mo = ec.getMatrixObject(getInput1().getName());
- mo.setPrivacyConstraints(getPrivacyConstraint());
mo.exportData(fname, fmtStr, _formatProperties);
}
+ // Set privacy constraint of write instruction to the same as that of the input
+ setPrivacyConstraint(ec.getMatrixObject(getInput1().getName()).getPrivacyConstraint());
}
else if( getInput1().getDataType() == DataType.FRAME ) {
FrameObject mo = ec.getFrameObject(getInput1().getName());
mo.exportData(fname, fmtStr, _formatProperties);
+ setPrivacyConstraint(mo.getPrivacyConstraint());
}
else if( getInput1().getDataType() == DataType.TENSOR ) {
// TODO write tensor
TensorObject to = ec.getTensorObject(getInput1().getName());
+ setPrivacyConstraint(to.getPrivacyConstraint());
to.exportData(fname, fmtStr, _formatProperties);
}
}
+
+ /**
+ * Handler for SetFileName instruction
+ * @param ec execution context
+ */
+ private void processSetFileNameInstruction(ExecutionContext ec){
+ Data data = ec.getVariable(getInput1().getName());
+ if ( data.getDataType() == DataType.MATRIX ) {
+ if ( getInput3().getName().equalsIgnoreCase("remote") )
+ ((MatrixObject)data).setFileName(getInput2().getName());
+ else
+ throw new DMLRuntimeException("Invalid location (" + getInput3().getName() + ") in SetFileName instruction: " + instString);
+ } else
+ throw new DMLRuntimeException("Invalid data type (" + getInput1().getDataType() + ") in SetFileName instruction: " + instString);
+ }
/**
* Remove variable instruction externalized as a static function in order to allow various
@@ -956,7 +1028,7 @@
else {
mo.exportData(fname, outFmt, _formatProperties);
}
- HDFSTool.writeMetaDataFile (fname + ".mtd", mo.getValueType(), dc, FileFormat.CSV, _formatProperties);
+ HDFSTool.writeMetaDataFile (fname + ".mtd", mo.getValueType(), dc, FileFormat.CSV, _formatProperties, mo.getPrivacyConstraint());
}
catch (IOException e) {
throw new DMLRuntimeException(e);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 9000200..fc064eb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -22,6 +22,7 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.privacy.PrivacyPropagator;
public abstract class FEDInstruction extends Instruction {
@@ -58,6 +59,8 @@
@Override
public Instruction preprocessInstruction(ExecutionContext ec) {
- return super.preprocessInstruction(ec);
+ Instruction tmp = super.preprocessInstruction(ec);
+ tmp = PrivacyPropagator.preprocessInstruction(tmp, ec);
+ return tmp;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java
index 99602ae..cf0d162 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReblockSPInstruction.java
@@ -88,7 +88,7 @@
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
mcOut.set(mc.getRows(), mc.getCols(), blen, mc.getNonZeros());
- //get the source format form the meta data
+ //get the source format from the meta data
MetaDataFormat iimd = (MetaDataFormat) obj.getMetaData();
if(iimd == null)
throw new DMLRuntimeException("Error: Metadata not found");
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/DMLPrivacyException.java b/src/main/java/org/apache/sysds/runtime/privacy/DMLPrivacyException.java
new file mode 100644
index 0000000..7e77b04
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/privacy/DMLPrivacyException.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.privacy;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+
+/**
+ * This exception should be thrown to flag DML runtime errors related to the violation of privacy constraints.
+ */
+public class DMLPrivacyException extends DMLRuntimeException
+{
+ private static final long serialVersionUID = 1L;
+
+ //prevent string concatenation of classname w/ stop message
+ private DMLPrivacyException(Exception e) {
+ super(e);
+ }
+
+ private DMLPrivacyException(String string, Exception ex){
+ super(string,ex);
+ }
+
+ /**
+ * This is the only valid constructor for DMLPrivacyException.
+ *
+ * @param msg message
+ */
+ public DMLPrivacyException(String msg) {
+ super(msg);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
index 2b32636..45b12be 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
@@ -24,19 +24,25 @@
*/
public class PrivacyConstraint
{
- protected boolean _privacy = false;
+ public enum PrivacyLevel {
+ None, // No data exchange constraints. Data can be shared with anyone.
+ Private, // Data cannot leave the origin.
+ PrivateAggregation // Only aggregations of the data can leave the origin.
+ }
- public PrivacyConstraint(){}
+ protected PrivacyLevel privacyLevel = PrivacyLevel.None;
- public PrivacyConstraint(boolean privacy) {
- _privacy = privacy;
- }
+ public PrivacyConstraint(){}
- public void setPrivacy(boolean privacy){
- _privacy = privacy;
- }
+ public PrivacyConstraint(PrivacyLevel privacyLevel) {
+ setPrivacyLevel(privacyLevel);
+ }
- public boolean getPrivacy(){
- return _privacy;
- }
-}
\ No newline at end of file
+ public void setPrivacyLevel(PrivacyLevel privacyLevel){
+ this.privacyLevel = privacyLevel;
+ }
+
+ public PrivacyLevel getPrivacyLevel(){
+ return privacyLevel;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
new file mode 100644
index 0000000..118a153
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.privacy;
+
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+
+public class PrivacyMonitor
+{
+ //TODO maybe maintain a log of checked constraints for transfers
+ // in order to provide 'privacy explanations' similar to our stats
+
+ /**
+ * Throws DMLPrivacyException if data object is CacheableData and privacy constraint is set to private or private aggregation.
+ * @param dataObject input data object
+ * @return data object or data object with privacy constraint removed in case the privacy level was none.
+ */
+ public static Data handlePrivacy(Data dataObject){
+ if ( dataObject instanceof CacheableData<?> ){
+ PrivacyConstraint privacyConstraint = ((CacheableData<?>)dataObject).getPrivacyConstraint();
+ if (privacyConstraint != null){
+ PrivacyLevel privacyLevel = privacyConstraint.getPrivacyLevel();
+ switch(privacyLevel){
+ case None:
+ ((CacheableData<?>)dataObject).setPrivacyConstraints(null);
+ break;
+ case Private:
+ case PrivateAggregation:
+ throw new DMLPrivacyException("Cannot share variable, since the privacy constraint of the requested variable is set to " + privacyLevel.name());
+ default:
+ throw new DMLPrivacyException("Privacy level " + privacyLevel.name() + " of variable not recognized");
+ }
+ }
+ }
+ return dataObject;
+ }
+
+ /**
+ * Throws DMLPrivacyException if the privacy constraint of the matrix object is set to Private.
+ * @param matrixObject input matrix object
+ * @return matrix object or matrix object with privacy constraint removed in case the privacy level was none.
+ */
+ public static MatrixObject handlePrivacy(MatrixObject matrixObject){
+ PrivacyConstraint privacyConstraint = matrixObject.getPrivacyConstraint();
+ if (privacyConstraint != null){
+ PrivacyLevel privacyLevel = privacyConstraint.getPrivacyLevel();
+ switch(privacyLevel){
+ case None:
+ matrixObject.setPrivacyConstraints(null);
+ break;
+ case Private:
+ throw new DMLPrivacyException("Cannot share variable, since the privacy constraint of the requested variable is set to " + privacyLevel.name());
+ case PrivateAggregation:
+ break;
+ default:
+ throw new DMLPrivacyException("Privacy level " + privacyLevel.name() + " of variable not recognized");
+ }
+ }
+ return matrixObject;
+ }
+
+ /**
+ * Throw DMLPrivacyException if privacy is activated for the input variable
+ * @param input variable for which the privacy constraint is checked
+ */
+ public static void handlePrivacyScalarOutput(CPOperand input, ExecutionContext ec) {
+ Data data = ec.getCacheableData(input);
+ if ( data != null && (data instanceof CacheableData<?>)){
+ PrivacyConstraint privacyConstraintIn = ((CacheableData<?>) data).getPrivacyConstraint();
+ if ( privacyConstraintIn != null && (privacyConstraintIn.getPrivacyLevel() == PrivacyLevel.Private) ){
+ throw new DMLPrivacyException("Privacy constraint cannot be propagated to scalar for input " + input.getName());
+ }
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
index 2070c99..323330a 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
@@ -19,20 +19,325 @@
package org.apache.sysds.runtime.privacy;
+import java.util.function.Function;
+
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
+
/**
* Class with static methods merging privacy constraints of operands
* in expressions to generate the privacy constraints of the output.
*/
-public class PrivacyPropagator {
+public class PrivacyPropagator
+{
+ public static CacheableData<?> parseAndSetPrivacyConstraint(CacheableData<?> cd, JSONObject mtd)
+ throws JSONException
+ {
+ if ( mtd.containsKey(DataExpression.PRIVACY) ) {
+ String privacyLevel = mtd.getString(DataExpression.PRIVACY);
+ if ( privacyLevel != null )
+ cd.setPrivacyConstraints(new PrivacyConstraint(PrivacyLevel.valueOf(privacyLevel)));
+ }
+ return cd;
+ }
+
+ public static PrivacyConstraint mergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) {
+ if (privacyConstraint1 != null && privacyConstraint2 != null){
+ PrivacyLevel privacyLevel1 = privacyConstraint1.getPrivacyLevel();
+ PrivacyLevel privacyLevel2 = privacyConstraint2.getPrivacyLevel();
- public static PrivacyConstraint MergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) {
- if (privacyConstraint1 != null && privacyConstraint2 != null)
- return new PrivacyConstraint(
- privacyConstraint1.getPrivacy() || privacyConstraint2.getPrivacy());
+ // One of the inputs is private, hence the output must be private.
+ if (privacyLevel1 == PrivacyLevel.Private || privacyLevel2 == PrivacyLevel.Private)
+ return new PrivacyConstraint(PrivacyLevel.Private);
+ // One of the inputs is private with aggregation allowed, but none of the inputs are completely private,
+ // hence the output must be private with aggregation.
+ else if (privacyLevel1 == PrivacyLevel.PrivateAggregation || privacyLevel2 == PrivacyLevel.PrivateAggregation)
+ return new PrivacyConstraint(PrivacyLevel.PrivateAggregation);
+ // Both inputs have privacy level "None", hence the privacy constraint can be removed.
+ else
+ return null;
+ }
else if (privacyConstraint1 != null)
return privacyConstraint1;
else if (privacyConstraint2 != null)
return privacyConstraint2;
return null;
}
+
+ public static PrivacyConstraint mergeTernary(PrivacyConstraint[] privacyConstraints){
+ return mergeBinary(mergeBinary(privacyConstraints[0], privacyConstraints[1]), privacyConstraints[2]);
+ }
+
+ public static PrivacyConstraint mergeQuaternary(PrivacyConstraint[] privacyConstraints){
+ return mergeBinary(mergeTernary(privacyConstraints), privacyConstraints[3]);
+ }
+
+ public static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraints){
+ PrivacyConstraint mergedPrivacyConstraint = privacyConstraints[0];
+ for ( int i = 1; i < privacyConstraints.length; i++ ){
+ mergedPrivacyConstraint = mergeBinary(mergedPrivacyConstraint, privacyConstraints[i]);
+ }
+ return mergedPrivacyConstraint;
+ }
+
+ public static Instruction preprocessInstruction(Instruction inst, ExecutionContext ec){
+ switch ( inst.getType() ){
+ case CONTROL_PROGRAM:
+ return preprocessCPInstruction( (CPInstruction) inst, ec );
+ case BREAKPOINT:
+ case SPARK:
+ case GPU:
+ case FEDERATED:
+ return inst;
+ default:
+ throwExceptionIfPrivacyActivated(inst, ec);
+ return inst;
+ }
+ }
+
+ public static Instruction preprocessCPInstruction(CPInstruction inst, ExecutionContext ec){
+ switch ( inst.getCPInstructionType() )
+ {
+ case Variable:
+ return preprocessVariableCPInstruction((VariableCPInstruction) inst, ec);
+ case AggregateUnary:
+ case Reorg:
+ case Unary:
+ return preprocessUnaryCPInstruction((UnaryCPInstruction) inst, ec);
+ case AggregateBinary:
+ case Append:
+ case Binary:
+ return preprocessBinaryCPInstruction((BinaryCPInstruction) inst, ec);
+ case AggregateTernary:
+ case Ternary:
+ return preprocessTernaryCPInstruction((ComputationCPInstruction) inst, ec);
+ case Quaternary:
+ return preprocessQuaternary((QuaternaryCPInstruction) inst, ec);
+ case BuiltinNary:
+ case Builtin:
+ return preprocessBuiltinNary((BuiltinNaryCPInstruction) inst, ec);
+ case Ctable:
+ case MultiReturnParameterizedBuiltin:
+ case MultiReturnBuiltin:
+ case ParameterizedBuiltin:
+ default:
+ return preprocessInstructionSimple(inst, ec);
+ }
+ }
+
+ /**
+ * Throw an exception if the privacy constraint of the instruction is activated; otherwise return the instruction unchanged
+ * @param inst instruction
+ * @param ec execution context
+ * @return instruction
+ */
+ public static Instruction preprocessInstructionSimple(Instruction inst, ExecutionContext ec){
+ throwExceptionIfPrivacyActivated(inst, ec);
+ return inst;
+ }
+
+ public static Instruction preprocessBuiltinNary(BuiltinNaryCPInstruction inst, ExecutionContext ec){
+ if (inst.getInputs() == null) return inst;
+ PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inst.getInputs());
+ PrivacyConstraint mergedPrivacyConstraint = mergeNary(privacyConstraints);
+ inst.setPrivacyConstraint(mergedPrivacyConstraint);
+ setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, inst.getOutput());
+ return inst;
+ }
+
+ public static Instruction preprocessQuaternary(QuaternaryCPInstruction inst, ExecutionContext ec){
+ PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec,
+ new CPOperand[] {inst.input1,inst.input2,inst.input3,inst.getInput4()});
+ PrivacyConstraint mergedPrivacyConstraint = mergeQuaternary(privacyConstraints);
+ inst.setPrivacyConstraint(mergedPrivacyConstraint);
+ setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, inst.output);
+ return inst;
+ }
+
+ public static Instruction preprocessTernaryCPInstruction(ComputationCPInstruction inst, ExecutionContext ec){
+ PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, new CPOperand[]{inst.input1, inst.input2, inst.input3});
+ PrivacyConstraint mergedPrivacyConstraint = mergeTernary(privacyConstraints);
+ inst.setPrivacyConstraint(mergedPrivacyConstraint);
+ setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, inst.output);
+ return inst;
+ }
+
+ public static Instruction preprocessNaryInstruction(CPInstruction inst, ExecutionContext ec, CPOperand[] inputs, CPOperand output, Function<PrivacyConstraint[], PrivacyConstraint> mergeFunction){
+ PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inputs);
+ PrivacyConstraint mergedPrivacyConstraint = mergeFunction.apply(privacyConstraints);
+ inst.setPrivacyConstraint(mergedPrivacyConstraint);
+ setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, output);
+ return inst;
+
+ }
+
+ public static Instruction preprocessBinaryCPInstruction(BinaryCPInstruction inst, ExecutionContext ec){
+ PrivacyConstraint privacyConstraint1 = getInputPrivacyConstraint(ec, inst.input1);
+ PrivacyConstraint privacyConstraint2 = getInputPrivacyConstraint(ec, inst.input2);
+ if ( privacyConstraint1 != null || privacyConstraint2 != null)
+ {
+ PrivacyConstraint mergedPrivacyConstraint = mergeBinary(privacyConstraint1, privacyConstraint2);
+ inst.setPrivacyConstraint(mergedPrivacyConstraint);
+ setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, inst.output);
+ }
+ return inst;
+ }
+
+ public static Instruction preprocessUnaryCPInstruction(UnaryCPInstruction inst, ExecutionContext ec){
+ return propagateInputPrivacy(inst, ec, inst.input1, inst.output);
+ }
+
+ public static Instruction preprocessVariableCPInstruction(VariableCPInstruction inst, ExecutionContext ec){
+ switch ( inst.getVariableOpcode() )
+ {
+ case CreateVariable:
+ return propagateSecondInputPrivacy(inst, ec);
+ case AssignVariable:
+ //Assigns scalar, hence it does not have privacy activated
+ return inst;
+ case CopyVariable:
+ case MoveVariable:
+ return propagateFirstInputPrivacy(inst, ec);
+ case RemoveVariable:
+ return propagateAllInputPrivacy(inst, ec);
+ case RemoveVariableAndFile:
+ return propagateFirstInputPrivacy(inst, ec);
+ case CastAsScalarVariable:
+ return propagateCastAsScalarVariablePrivacy(inst, ec);
+ case CastAsMatrixVariable:
+ case CastAsFrameVariable:
+ return propagateFirstInputPrivacy(inst, ec);
+ case CastAsDoubleVariable:
+ case CastAsIntegerVariable:
+ case CastAsBooleanVariable:
+ return propagateCastAsScalarVariablePrivacy(inst, ec);
+ case Read:
+ return inst;
+ case Write:
+ return propagateFirstInputPrivacy(inst, ec);
+ case SetFileName:
+ return propagateFirstInputPrivacy(inst, ec);
+ default:
+ throwExceptionIfPrivacyActivated(inst, ec);
+ return inst;
+ }
+ }
+
+ private static void throwExceptionIfPrivacyActivated(Instruction inst, ExecutionContext ec){
+ if ( inst.getPrivacyConstraint() != null && inst.getPrivacyConstraint().getPrivacyLevel().equals(PrivacyLevel.Private) ) {
+ throw new DMLPrivacyException("Instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
+ }
+ }
+
+ /**
+ * Propagate privacy from first input and throw exception if privacy is activated.
+ * @param inst Instruction
+ * @param ec execution context
+ * @return instruction with or without privacy constraints
+ */
+ private static Instruction propagateCastAsScalarVariablePrivacy(VariableCPInstruction inst, ExecutionContext ec){
+ inst = (VariableCPInstruction) propagateFirstInputPrivacy(inst, ec);
+ return preprocessInstructionSimple(inst, ec);
+ }
+
+ /**
+ * Propagate privacy constraints from all inputs if privacy constraints are set.
+ * @param inst instruction
+ * @param ec execution context
+ * @return instruction with or without privacy constraints
+ */
+ private static Instruction propagateAllInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){
+ //TODO: Propagate the most restrictive constraint instead of just the latest activated constraint
+ for ( CPOperand input : inst.getInputs() )
+ inst = (VariableCPInstruction) propagateInputPrivacy(inst, ec, input, inst.getOutput());
+ return inst;
+ }
+
+ /**
+ * Propagate privacy constraint to instruction and output of instruction
+ * if data of first input is CacheableData and
+ * privacy constraint is activated.
+ * @param inst VariableCPInstruction
+ * @param ec execution context
+ * @return instruction with or without privacy constraints
+ */
+ private static Instruction propagateFirstInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){
+ return propagateInputPrivacy(inst, ec, inst.getInput1(), inst.getOutput());
+ }
+
+ /**
+ * Propagate privacy constraint to instruction and output of instruction
+ * if data of second input is CacheableData and
+ * privacy constraint is activated.
+ * @param inst VariableCPInstruction
+ * @param ec execution context
+ * @return instruction with or without privacy constraints
+ */
+ private static Instruction propagateSecondInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){
+ return propagateInputPrivacy(inst, ec, inst.getInput2(), inst.getOutput());
+ }
+
+ /**
+ * Propagate privacy constraint to instruction and output of instruction
+ * if data of the specified variable is CacheableData
+ * and privacy constraint is activated
+ * @param inst instruction
+ * @param ec execution context
+ * @param inputOperand input from which the privacy constraint is found
+ * @param outputOperand output which the privacy constraint is propagated to
+ * @return instruction with or without privacy constraints
+ */
+ private static Instruction propagateInputPrivacy(Instruction inst, ExecutionContext ec, CPOperand inputOperand, CPOperand outputOperand){
+ PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, inputOperand);
+ if ( privacyConstraint != null ) {
+ inst.setPrivacyConstraint(privacyConstraint);
+ if ( outputOperand != null)
+ setOutputPrivacyConstraint(ec, privacyConstraint, outputOperand);
+ }
+ return inst;
+ }
+
+ private static PrivacyConstraint getInputPrivacyConstraint(ExecutionContext ec, CPOperand input){
+ if ( input != null && input.getName() != null){
+ Data dd = ec.getVariable(input.getName());
+ if ( dd != null && dd instanceof CacheableData)
+ return ((CacheableData<?>) dd).getPrivacyConstraint();
+ }
+ return null;
+ }
+
+
+ private static PrivacyConstraint[] getInputPrivacyConstraints(ExecutionContext ec, CPOperand[] inputs){
+ PrivacyConstraint[] privacyConstraints = new PrivacyConstraint[inputs.length];
+ for ( int i = 0; i < inputs.length; i++ ){
+ privacyConstraints[i] = getInputPrivacyConstraint(ec, inputs[i]);
+ }
+ return privacyConstraints;
+
+ }
+
+ private static void setOutputPrivacyConstraint(ExecutionContext ec, PrivacyConstraint privacyConstraint, CPOperand output){
+ Data dd = ec.getVariable(output.getName());
+ if ( dd != null ){
+ if ( dd instanceof CacheableData ){
+ ((CacheableData<?>) dd).setPrivacyConstraints(privacyConstraint);
+ ec.setVariable(output.getName(), dd);
+ }
+ else throw new DMLPrivacyException("Privacy constraint of " + output + " cannot be set since it is not an instance of CacheableData");
+ }
+ }
}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
index 8ea5ff5..8b1e42e 100644
--- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
@@ -361,6 +361,11 @@
throws IOException {
writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, dc, fmt, formatProperties);
}
+
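+ /**
+ * Writes the metadata file for a matrix, including the given privacy constraint
+ * (delegates to the full writeMetaDataFile with a null schema and DataType.MATRIX).
+ */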
+ public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics dc, FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint)
+ throws IOException {
+ writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, dc, fmt, formatProperties, privacyConstraint);
+ }
public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc,
FileFormat fmt, FileFormatProperties formatProperties)
@@ -452,7 +457,7 @@
//add privacy constraints
if ( privacyConstraint != null ){
- mtd.put(DataExpression.PRIVACY, privacyConstraint.getPrivacy());
+ mtd.put(DataExpression.PRIVACY, privacyConstraint.getPrivacyLevel().name());
}
//add username and time
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedL2SVMTest.java
new file mode 100644
index 0000000..c93b660
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedL2SVMTest.java
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy;
+
+import org.junit.Test;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.wink.json4j.JSONException;
+
+import java.util.HashMap;
+import java.util.Map;
+
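+/**
+ * Tests runtime propagation of privacy constraints for federated L2SVM. Privacy constraints are
+ * written to the metadata of either the inputs of the federated run (X1, X2, Y) or the inputs
+ * of the local reference run (MX1, MX2, MY), and each test checks whether the corresponding run
+ * is expected to fail with a DMLException.
+ */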
+@net.jcip.annotations.NotThreadSafe
+public class FederatedL2SVMTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedL2SVMTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedL2SVMTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ private int rows = 100;
+ private int cols = 10;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ // PrivateAggregation Single Input
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationX1() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationX2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationY() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ }
+
+ // Private Single Input
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedX1() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedX2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedY() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private);
+ }
+
+ // Setting Privacy of Matrix (Throws Exception)
+
+ @Test
+ public void federatedL2SVMCPPrivateMatrixX1() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateMatrixX2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateMatrixY() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedAndMatrixX1() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedAndMatrixX2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedAndMatrixY() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, privacyConstraints, PrivacyLevel.Private, true, DMLException.class, false, null);
+ }
+
+ // Privacy Level Private Combinations
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedX1X2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedX1Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedX2Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateFederatedX1X2Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ // Privacy Level PrivateAggregation Combinations
+ @Test
+ public void federatedL2SVMCPPrivateAggregationFederatedX1X2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationFederatedX1Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationFederatedX2Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationFederatedX1X2Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ }
+
+ // Privacy Level Combinations
+ @Test
+ public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1X2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2Y() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX1() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX2() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("Y", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2X1() throws JSONException {
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ // Require Federated Workers to return matrix
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationX1Exception() throws JSONException {
+ rows = 1000; cols = 1;
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateAggregationX2Exception() throws JSONException {
+ rows = 1000; cols = 1;
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.PrivateAggregation, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateX1Exception() throws JSONException {
+ rows = 1000; cols = 1;
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ @Test
+ public void federatedL2SVMCPPrivateX2Exception() throws JSONException {
+ rows = 1000; cols = 1;
+ Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
+ privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.Private));
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, null, PrivacyLevel.Private, false, null, true, DMLException.class);
+ }
+
+ private void federatedL2SVMNoException(Types.ExecMode execMode,
+ Map<String, PrivacyConstraint> privacyConstraintsFederated,
+ Map<String, PrivacyConstraint> privacyConstraintsMatrix,
+ PrivacyLevel expectedPrivacyLevel)
+ throws JSONException
+ {
+ federatedL2SVM(execMode, privacyConstraintsFederated, privacyConstraintsMatrix, expectedPrivacyLevel, false, null, false, null);
+ }
+
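+ /**
+ * Runs the reference L2SVM script on the local matrices (MX1, MX2, MY) and the federated
+ * L2SVM script on the federated matrices (X1, X2) plus Y, with the given privacy constraints
+ * written to the respective metadata files. The first exception flag and class apply to the
+ * reference run, the second pair to the federated run; results are compared only if neither
+ * run is expected to throw.
+ */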
+ private void federatedL2SVM(Types.ExecMode execMode, Map<String, PrivacyConstraint> privacyConstraintsFederated,
+ Map<String, PrivacyConstraint> privacyConstraintsMatrix, PrivacyLevel expectedPrivacyLevel,
+ boolean exception1, Class<?> expectedException1, boolean exception2, Class<?> expectedException2 )
+ throws JSONException
+ {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ Thread t1 = null, t2 = null;
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233);
+ for(int i = 0; i < rows; i++)
+ Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
+
+ // Write privacy constraints of normal matrix
+ if ( privacyConstraintsMatrix != null ){
+ writeInputMatrixWithMTD("MX1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsMatrix.get("X1"));
+ writeInputMatrixWithMTD("MX2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsMatrix.get("X2"));
+ writeInputMatrixWithMTD("MY", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows), privacyConstraintsMatrix.get("Y"));
+ } else {
+ writeInputMatrixWithMTD("MX1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("MX2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("MY", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows));
+ }
+
+ // Write privacy constraints of federated matrix
+ if ( privacyConstraintsFederated != null ){
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsFederated.get("X1"));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), privacyConstraintsFederated.get("X2"));
+ writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows), privacyConstraintsFederated.get("Y"));
+ } else {
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows));
+ }
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ t1 = startLocalFedWorker(port1);
+ t2 = startLocalFedWorker(port2);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("MX1"), input("MX2"), input("MY"), expected("Z")};
+ runTest(true, exception1, expectedException1, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-args", "\"localhost:" + port1 + "/" + input("X1") + "\"",
+ "\"localhost:" + port2 + "/" + input("X2") + "\"", Integer.toString(rows), Integer.toString(cols),
+ Integer.toString(halfRows), input("Y"), output("Z")};
+ runTest(true, exception2, expectedException2, -1);
+
+ if ( !(exception1 || exception2) ) {
+ compareResults(1e-9);
+ }
+ }
+ finally {
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
new file mode 100644
index 0000000..f74e3a9
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
@@ -0,0 +1,339 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy;
+
+import java.util.Arrays;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.apache.sysds.common.Types;
+import static java.lang.Thread.sleep;
+
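+/**
+ * Tests the privacy handling of FederatedWorkerHandler: federated requests for scalar addition,
+ * aggregation (sum), matrix transfer (rbind/cbind), and matrix-vector multiplication are run
+ * against data with privacy levels Private, PrivateAggregation, and None, and the tests check
+ * which requests are rejected with a DMLException.
+ */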
+public class FederatedWorkerHandlerTest extends AutomatedTestBase {
+
+ private static final String TEST_DIR = "functions/federated/";
+ private static final String TEST_DIR_SCALAR = TEST_DIR + "matrix_scalar/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWorkerHandlerTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR_SCALAR = TEST_DIR_SCALAR + FederatedWorkerHandlerTest.class.getSimpleName() + "/";
+ private static final String TEST_PROG_SCALAR_ADDITION_MATRIX = "ScalarAdditionFederatedMatrix";
+ private final static String AGGREGATION_TEST_NAME = "FederatedSumTest";
+ private final static String TRANSFER_TEST_NAME = "FederatedRCBindTest";
+ private final static String MATVECMULT_TEST_NAME = "FederatedMultiplyTest";
+ private static final String FEDERATED_WORKER_HOST = "localhost";
+ private static final int FEDERATED_WORKER_PORT = 1222;
+
+ private final static int blocksize = 1024;
+ private int rows = 10;
+ private int cols = 10;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration("scalar", new TestConfiguration(TEST_CLASS_DIR_SCALAR, TEST_PROG_SCALAR_ADDITION_MATRIX, new String [] {"R"}));
+ addTestConfiguration("aggregation", new TestConfiguration(TEST_CLASS_DIR, AGGREGATION_TEST_NAME, new String[] {"S.scalar", "R", "C"}));
+ addTestConfiguration("transfer", new TestConfiguration(TEST_CLASS_DIR, TRANSFER_TEST_NAME, new String[] {"R", "C"}));
+ addTestConfiguration("matvecmult", new TestConfiguration(TEST_CLASS_DIR, MATVECMULT_TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Test
+ public void scalarPrivateTest(){
+ scalarTest(PrivacyLevel.Private, DMLException.class);
+ }
+
+ @Test
+ public void scalarPrivateAggregationTest(){
+ scalarTest(PrivacyLevel.PrivateAggregation, DMLException.class);
+ }
+
+ @Test
+ public void scalarNonePrivateTest(){
+ scalarTest(PrivacyLevel.None, null);
+ }
+
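+ /**
+ * Writes matrix M with the given privacy level, computes the expected result of adding a
+ * random scalar, and runs the federated scalar-addition script, expecting the given
+ * exception for private data.
+ */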
+ private void scalarTest(PrivacyLevel privacyLevel, Class<?> expectedException){
+ getAndLoadTestConfiguration("scalar");
+
+ double[][] m = getRandomMatrix(this.rows, this.cols, -1, 1, 1.0, 1);
+
+ PrivacyConstraint pc = new PrivacyConstraint(privacyLevel);
+ writeInputMatrixWithMTD("M", m, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols), pc);
+
+ int s = TestUtils.getRandomInt();
+ double[][] r = new double[rows][cols];
+ for(int i = 0; i < rows; i++) {
+ for(int j = 0; j < cols; j++) {
+ r[i][j] = m[i][j] + s;
+ }
+ }
+ if (expectedException == null)
+ writeExpectedMatrix("R", r);
+
+ runGenericScalarTest(TEST_PROG_SCALAR_ADDITION_MATRIX, s, expectedException);
+ }
+
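+ /**
+ * Starts a local federated worker, runs the given scalar-addition script against it with
+ * scalar s, and compares the result unless the given exception is expected.
+ */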
+ private void runGenericScalarTest(String dmlFile, int s, Class<?> expectedException)
+ {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+
+ Thread t = null;
+ try {
+ // run in single-node mode so the reference file is not written to HDFS and keeps the expected format
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+ if (rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ programArgs = new String[] {"-w", Integer.toString(FEDERATED_WORKER_PORT)};
+ t = new Thread(() -> runTest(true, false, null, -1));
+ t.start();
+ sleep(FED_WORKER_WAIT);
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR_SCALAR + dmlFile + ".dml";
+ programArgs = new String[]{"-args",
+ TestUtils.federatedAddress(FEDERATED_WORKER_HOST, FEDERATED_WORKER_PORT, input("M")),
+ Integer.toString(rows), Integer.toString(cols),
+ Integer.toString(s),
+ output("R")};
+ boolean exceptionExpected = (expectedException != null);
+ runTest(true, exceptionExpected, expectedException, -1);
+
+ if ( !exceptionExpected )
+ compareResults();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ org.junit.Assert.fail("Waiting for the federated worker was interrupted: " + e.getMessage());
+ } finally {
+ TestUtils.shutdownThread(t);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+
+ @Test
+ public void aggregatePrivateTest() {
+ federatedSum(Types.ExecMode.SINGLE_NODE, PrivacyLevel.Private, DMLException.class);
+ }
+
+ @Test
+ public void aggregatePrivateAggregationTest() {
+ federatedSum(Types.ExecMode.SINGLE_NODE, PrivacyLevel.PrivateAggregation, null);
+ }
+
+ @Test
+ public void aggregateNonePrivateTest() {
+ federatedSum(Types.ExecMode.SINGLE_NODE, PrivacyLevel.None, null);
+ }
+
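+ /**
+ * Runs the federated aggregation script (sum and row/column sums) against a worker holding
+ * matrix A with the given privacy level. Private is expected to fail, while PrivateAggregation
+ * and None succeed, since only aggregated values leave the worker.
+ */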
+ public void federatedSum(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+
+ Thread t = null;
+
+ getAndLoadTestConfiguration("aggregation");
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ double[][] A = getRandomMatrix(rows, cols, -10, 10, 1, 1);
+ writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols), new PrivacyConstraint(privacyLevel));
+ int port = getRandomAvailablePort();
+ t = startLocalFedWorker(port);
+
+ // run the reference script in single-node mode so the reference file is not written to HDFS and keeps the expected format
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+ // Run reference dml script with normal matrix for Row/Col sum
+ fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
+ runTest(true, false, null, -1);
+
+ // write expected sum
+ double sum = 0;
+ for(double[] doubles : A) {
+ sum += Arrays.stream(doubles).sum();
+ }
+ sum *= 2;
+
+ if ( expectedException == null )
+ writeExpectedScalar("S", sum);
+
+ // reference file should not be written to hdfs, so we set platform here
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config = availableTestConfigurations.get("aggregation");
+ loadTestConfiguration(config);
+ fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml";
+ programArgs = new String[] {"-args", "\"localhost:" + port + "/" + input("A") + "\"", Integer.toString(rows),
+ Integer.toString(cols), Integer.toString(rows * 2), output("S"), output("R"), output("C")};
+
+ runTest(true, (expectedException != null), expectedException, -1);
+
+ // compare all sums via files
+ if ( expectedException == null )
+ compareResults(1e-11);
+
+ TestUtils.shutdownThread(t);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+
+ @Test
+ public void transferPrivateTest() {
+ federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.Private, DMLException.class);
+ }
+
+ @Test
+ public void transferPrivateAggregationTest() {
+ federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.PrivateAggregation, DMLException.class);
+ }
+
+ @Test
+ public void transferNonePrivateTest() {
+ federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.None, null);
+ }
+
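+ /**
+ * Runs the federated rbind/cbind script against a worker holding matrix A with the given
+ * privacy level. Both Private and PrivateAggregation are expected to fail, since rbind/cbind
+ * requires the worker to transfer the full federated matrix.
+ */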
+ public void federatedRCBind(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+
+ Thread t = null;
+
+ getAndLoadTestConfiguration("transfer");
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ double[][] A = getRandomMatrix(rows, cols, -10, 10, 1, 1);
+ writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols), new PrivacyConstraint(privacyLevel));
+
+ int port = getRandomAvailablePort();
+ t = startLocalFedWorker(port);
+
+ // run the reference script in single-node mode so the reference file is not written to HDFS and keeps the expected format
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+ // Run reference dml script with normal matrix for Row/Col sum
+ fullDMLScriptName = HOME + TRANSFER_TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("A"), expected("R"), expected("C")};
+ runTest(true, false, null, -1);
+
+ // reference file should not be written to hdfs, so we set platform here
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config = availableTestConfigurations.get("transfer");
+ loadTestConfiguration(config);
+ fullDMLScriptName = HOME + TRANSFER_TEST_NAME + ".dml";
+ programArgs = new String[] {"-args", "\"localhost:" + port + "/" + input("A") + "\"", Integer.toString(rows),
+ Integer.toString(cols), output("R"), output("C")};
+
+ runTest(true, (expectedException != null), expectedException, -1);
+
+ // compare all sums via files
+ if ( expectedException == null )
+ compareResults(1e-11);
+
+ TestUtils.shutdownThread(t);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+
+ @Test
+ public void matVecMultPrivateTest() {
+ federatedMultiply(Types.ExecMode.SINGLE_NODE, PrivacyLevel.Private, DMLException.class);
+ }
+
+ @Test
+ public void matVecMultPrivateAggregationTest() {
+ federatedMultiply(Types.ExecMode.SINGLE_NODE, PrivacyLevel.PrivateAggregation, DMLException.class);
+ }
+
+ @Test
+ public void matVecMultNonePrivateTest() {
+ federatedMultiply(Types.ExecMode.SINGLE_NODE, PrivacyLevel.None, null);
+ }
+
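+ /**
+ * Runs the federated matrix multiplication script with X1 held by a worker under the given
+ * privacy level. Private and PrivateAggregation are both expected to fail, since the worker
+ * has to return non-aggregated results.
+ */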
+ public void federatedMultiply(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+
+ Thread t1, t2;
+
+ getAndLoadTestConfiguration("matvecmult");
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ // And another two matrices handled by a single federated worker
+ double[][] Y1 = getRandomMatrix(cols, halfRows, 0, 1, 1, 44);
+ double[][] Y2 = getRandomMatrix(cols, halfRows, 0, 1, 1, 21);
+
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), new PrivacyConstraint(privacyLevel));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ t1 = startLocalFedWorker(port1);
+ t2 = startLocalFedWorker(port2);
+
+ TestConfiguration config = availableTestConfigurations.get("matvecmult");
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + MATVECMULT_TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z")};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + MATVECMULT_TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs",
+ "X1=" + TestUtils.federatedAddress("localhost", port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress("localhost", port2, input("X2")),
+ "Y1=" + TestUtils.federatedAddress("localhost", port1, input("Y1")),
+ "Y2=" + TestUtils.federatedAddress("localhost", port2, input("Y2")), "r=" + rows, "c=" + cols,
+ "hr=" + halfRows, "Z=" + output("Z")};
+ runTest(true, (expectedException != null), expectedException, -1);
+
+ // compare via files
+ if (expectedException == null)
+ compareResults(1e-9);
+
+ TestUtils.shutdownThreads(t1, t2);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java
index a16355a..0715b0a 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java
@@ -27,6 +27,7 @@
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -46,26 +47,36 @@
}
@Test
- public void testMatrixMultiplicationPropagation() throws JSONException {
- matrixMultiplicationPropagation(true, true);
+ public void testMatrixMultiplicationPropagationPrivate() throws JSONException {
+ matrixMultiplicationPropagation(PrivacyLevel.Private, true);
}
@Test
- public void testMatrixMultiplicationPropagationFalse() throws JSONException {
- matrixMultiplicationPropagation(false, true);
+ public void testMatrixMultiplicationPropagationNone() throws JSONException {
+ matrixMultiplicationPropagation(PrivacyLevel.None, true);
}
@Test
- public void testMatrixMultiplicationPropagationSecondOperand() throws JSONException {
- matrixMultiplicationPropagation(true, false);
+ public void testMatrixMultiplicationPropagationPrivateAggregation() throws JSONException {
+ matrixMultiplicationPropagation(PrivacyLevel.PrivateAggregation, true);
}
@Test
- public void testMatrixMultiplicationPropagationSecondOperandFalse() throws JSONException {
- matrixMultiplicationPropagation(false, false);
+ public void testMatrixMultiplicationPropagationSecondOperandPrivate() throws JSONException {
+ matrixMultiplicationPropagation(PrivacyLevel.Private, false);
}
- private void matrixMultiplicationPropagation(boolean privacy, boolean privateFirstOperand) throws JSONException {
+ @Test
+ public void testMatrixMultiplicationPropagationSecondOperandNone() throws JSONException {
+ matrixMultiplicationPropagation(PrivacyLevel.None, false);
+ }
+
+ @Test
+ public void testMatrixMultiplicationPropagationSecondOperandPrivateAggregation() throws JSONException {
+ matrixMultiplicationPropagation(PrivacyLevel.PrivateAggregation, false);
+ }
+
+ private void matrixMultiplicationPropagation(PrivacyLevel privacyLevel, boolean privateFirstOperand) throws JSONException {
TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest");
loadTestConfiguration(config);
@@ -78,7 +89,7 @@
double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1);
double[][] c = TestUtils.performMatrixMultiplication(a, b);
- PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacy);
+ PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacyLevel);
MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k);
if ( privateFirstOperand ) {
@@ -99,7 +110,7 @@
// Check that the output metadata is correct
String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY);
- assertEquals(String.valueOf(privacy), actualPrivacyValue);
+ assertEquals(String.valueOf(privacyLevel), actualPrivacyValue);
}
@Test
@@ -144,28 +155,32 @@
}
@Test
- public void testMatrixMultiplicationPrivacyInputTrue() throws JSONException {
- testMatrixMultiplicationPrivacyInput(true);
+ public void testMatrixMultiplicationPrivacyInputPrivate() throws JSONException {
+ testMatrixMultiplicationPrivacyInput(PrivacyLevel.Private);
}
@Test
- public void testMatrixMultiplicationPrivacyInputFalse() throws JSONException {
- testMatrixMultiplicationPrivacyInput(false);
+ public void testMatrixMultiplicationPrivacyInputNone() throws JSONException {
+ testMatrixMultiplicationPrivacyInput(PrivacyLevel.None);
}
- private void testMatrixMultiplicationPrivacyInput(boolean privacy) throws JSONException {
+ @Test
+ public void testMatrixMultiplicationPrivacyInputPrivateAggregation() throws JSONException {
+ testMatrixMultiplicationPrivacyInput(PrivacyLevel.PrivateAggregation);
+ }
+
+ private void testMatrixMultiplicationPrivacyInput(PrivacyLevel privacyLevel) throws JSONException {
TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest");
loadTestConfiguration(config);
double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1);
- PrivacyConstraint privacyConstraint = new PrivacyConstraint();
- privacyConstraint.setPrivacy(privacy);
+ PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacyLevel);
MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k);
writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint);
String actualPrivacyValue = readDMLMetaDataValue("a", INPUT_DIR, DataExpression.PRIVACY);
- assertEquals(String.valueOf(privacy), actualPrivacyValue);
+ assertEquals(String.valueOf(privacyLevel), actualPrivacyValue);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixRuntimePropagationTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixRuntimePropagationTest.java
new file mode 100644
index 0000000..a72ea32
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixRuntimePropagationTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.wink.json4j.JSONException;
+import org.junit.Test;
+
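+/**
+ * Tests runtime propagation of privacy constraints through data-dependent control flow:
+ * the test steers a branch so that either the constrained matrix A or the unconstrained
+ * matrix B is written, and the output metadata is checked accordingly.
+ */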
+public class MatrixRuntimePropagationTest extends AutomatedTestBase
+{
+ private static final String TEST_DIR = "functions/privacy/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + MatrixRuntimePropagationTest.class.getSimpleName() + "/";
+ private final int m = 20;
+ private final int n = 20;
+ private final int k = 20;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration("MatrixRuntimePropagationTest",
+ new TestConfiguration(TEST_CLASS_DIR, "MatrixRuntimePropagationTest", new String[]{"c"}));
+ }
+
+ @Test
+ public void testRuntimePropagationPrivate() throws JSONException {
+ conditionalPropagation(PrivacyLevel.Private);
+ }
+
+ @Test
+ public void testRuntimePropagationNone() throws JSONException {
+ conditionalPropagation(PrivacyLevel.None);
+ }
+
+ @Test
+ public void testRuntimePropagationPrivateAggregation() throws JSONException {
+ conditionalPropagation(PrivacyLevel.PrivateAggregation);
+ }
+
+ private void conditionalPropagation(PrivacyLevel privacyLevel) throws JSONException {
+
+ TestConfiguration config = availableTestConfigurations.get("MatrixRuntimePropagationTest");
+ loadTestConfiguration(config);
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + config.getTestScript() + ".dml";
+
+ double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1);
+ double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1);
+ double sum;
+
+ PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacyLevel);
+ MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k);
+
+ writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint);
+ writeInputMatrix("b", b);
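+ // The DML script writes A if sum(A) < $s, otherwise B. Choosing s slightly above or below
+ // sum(A) steers the branch so that the privacy constraint of A either is or is not
+ // propagated to the output.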
+ if ( privacyLevel.equals(PrivacyLevel.Private) || privacyLevel.equals(PrivacyLevel.PrivateAggregation) ){
+ writeExpectedMatrix("c", a);
+ sum = TestUtils.sum(a, m, n) + 1;
+ } else {
+ writeExpectedMatrix("c", b);
+ sum = TestUtils.sum(a, m, n) - 1;
+ }
+
+ programArgs = new String[]{"-nvargs",
+ "a=" + input("a"), "b=" + input("b"), "c=" + output("c"),
+ "m=" + m, "n=" + n, "k=" + k, "s=" + sum };
+
+ runTest(true,false,null,-1);
+
+ // Check that the output data is correct
+ compareResults(1e-9);
+
+ // Check that the output metadata is correct
+ if ( privacyLevel.equals(PrivacyLevel.Private) ) {
+ String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY);
+ assertEquals(PrivacyLevel.Private.name(), actualPrivacyValue);
+ }
+ else if ( privacyLevel.equals(PrivacyLevel.PrivateAggregation) ){
+ String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY);
+ assertEquals(PrivacyLevel.PrivateAggregation.name(), actualPrivacyValue);
+ }
+ else {
+ // Check that either a JSONException is thrown or the privacy level is None,
+ // because no privacy metadata should be written for c unless the output
+ // carries a privacy constraint (Private or PrivateAggregation)
+ boolean jsonExceptionThrown = false;
+ String actualPrivacyValue = null;
+ try {
+ actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY);
+ } catch (JSONException e){
+ jsonExceptionThrown = true;
+ } catch (Exception e){
+ e.printStackTrace();
+ fail("Exception occurred, but JSONException was expected. The exception thrown is: " + e.getMessage());
+ }
+ assertTrue(jsonExceptionThrown || PrivacyLevel.None.name().equals(actualPrivacyValue));
+ }
+ }
+}
diff --git a/src/test/scripts/functions/privacy/MatrixRuntimePropagationTest.dml b/src/test/scripts/functions/privacy/MatrixRuntimePropagationTest.dml
new file mode 100644
index 0000000..b51cbf3
--- /dev/null
+++ b/src/test/scripts/functions/privacy/MatrixRuntimePropagationTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
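+# Writes matrix A to $c if sum(A) is below the threshold $s, otherwise writes matrix B.
+# Used to test runtime propagation of privacy constraints through data-dependent control flow.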
+A = read($a, rows=$m, cols=$n, format="text");
+B = read($b, rows=$n, cols=$k, format="text");
+if ( sum(A) < $s){
+ write(A, $c, format="text");
+} else {
+ write(B, $c, format="text");
+}
\ No newline at end of file