[SYSTEMDS-2568] Privacy Runtime Extended
Add FederatedWorkerHandlerException And Improved Handling of
Exceptions in FederatedWorkerHandler
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 8accfea..c87490a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -47,6 +47,7 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.LocalFileUtils;
@@ -322,6 +323,8 @@
public void setPrivacyConstraints(PrivacyConstraint pc) {
_privacyConstraint = pc;
+ if ( DMLScript.CHECK_PRIVACY && pc != null )
+ CheckedConstraintsLog.addLoadedConstraint(pc.getPrivacyLevel());
}
public PrivacyConstraint getPrivacyConstraint() {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 5848e9e..f850e73 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -19,6 +19,13 @@
package org.apache.sysds.runtime.controlprogram.caching;
+import static org.apache.sysds.runtime.util.UtilFunctions.requestFederatedData;
+
+import java.io.IOException;
+import java.lang.ref.SoftReference;
+import java.util.List;
+import java.util.concurrent.Future;
+
import org.apache.commons.lang.mutable.MutableBoolean;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
@@ -41,18 +48,10 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
-import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.IndexRange;
-import java.io.IOException;
-import java.lang.ref.SoftReference;
-import java.util.List;
-import java.util.concurrent.Future;
-
-import static org.apache.sysds.runtime.util.UtilFunctions.requestFederatedData;
-
/**
* Represents a matrix in control program. This class contains method to read
* matrices from HDFS and convert them to a specific format/representation. It
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index ce07488..771f828 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -35,7 +35,7 @@
private FedMethod _method;
private List<Object> _data;
- private boolean checkPrivacy;
+ private boolean _checkPrivacy;
public FederatedRequest(FedMethod method, List<Object> data) {
_method = method;
@@ -82,7 +82,7 @@
}
public void setCheckPrivacy(boolean checkPrivacy){
- this.checkPrivacy = checkPrivacy;
+ this._checkPrivacy = checkPrivacy;
}
public void setCheckPrivacy(){
@@ -90,6 +90,6 @@
}
public boolean checkPrivacy(){
- return checkPrivacy;
+ return _checkPrivacy;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index c187051..6c79457 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -65,7 +65,11 @@
}
public String getErrorMessage() {
- return ExceptionUtils.getFullStackTrace( (Exception) _data[0] );
+ if (_data[0] instanceof Throwable )
+ return ExceptionUtils.getFullStackTrace( (Throwable) _data[0] );
+ else if (_data[0] instanceof String)
+ return (String) _data[0];
+ else return "No readable error message";
}
public Object[] getData() throws Exception {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 27e20e2..3c688f5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -48,6 +48,7 @@
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.privacy.PrivacyPropagator;
import org.apache.sysds.utils.JSONHelper;
@@ -114,12 +115,17 @@
return executeScalarOperation(request);
default:
String message = String.format("Method %s is not supported.", method);
- return new FederatedResponse(FederatedResponse.Type.ERROR, message);
+ return new FederatedResponse(FederatedResponse.Type.ERROR, new FederatedWorkerHandlerException(message));
}
}
- catch (Exception exception) {
+ catch (DMLPrivacyException | FederatedWorkerHandlerException exception) {
return new FederatedResponse(FederatedResponse.Type.ERROR, exception);
}
+ catch (Exception exception) {
+ return new FederatedResponse(FederatedResponse.Type.ERROR,
+ new FederatedWorkerHandlerException("Exception of type "
+ + exception.getClass() + " thrown when processing request"));
+ }
}
private FederatedResponse readData(FederatedRequest request, Types.DataType dataType) {
@@ -141,7 +147,8 @@
break;
default:
// should NEVER happen (if we keep request codes in sync with actual behaviour)
- return new FederatedResponse(FederatedResponse.Type.ERROR, "Could not recognize datatype");
+ return new FederatedResponse(FederatedResponse.Type.ERROR,
+ new FederatedWorkerHandlerException("Could not recognize datatype"));
}
// read metadata
@@ -153,7 +160,7 @@
try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
JSONObject mtd = JSONHelper.parse(br);
if (mtd == null)
- return new FederatedResponse(FederatedResponse.Type.ERROR, "Could not parse metadata file");
+ return new FederatedResponse(FederatedResponse.Type.ERROR, new FederatedWorkerHandlerException("Could not parse metadata file"));
mc.setRows(mtd.getLong(DataExpression.READROWPARAM));
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
cd = PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
@@ -224,7 +231,7 @@
// TODO rest of the possible datatypes
default:
return new FederatedResponse(FederatedResponse.Type.ERROR,
- "FederatedWorkerHandler: Not possible to send datatype " + dataObject.getDataType().name());
+ new FederatedWorkerHandlerException("Not possible to send datatype " + dataObject.getDataType().name()));
}
}
@@ -238,9 +245,9 @@
private FederatedResponse executeAggregation(long varID, AggregateUnaryOperator operator) {
Data dataObject = _vars.get(varID);
if (dataObject.getDataType() != Types.DataType.MATRIX) {
- return new FederatedResponse(FederatedResponse.Type.ERROR,
- "FederatedWorkerHandler: Aggregation only supported for matrices, not for "
- + dataObject.getDataType().name());
+ return new FederatedResponse(FederatedResponse.Type.ERROR,
+ new FederatedWorkerHandlerException("Aggregation only supported for matrices, not for "
+ + dataObject.getDataType().name()));
}
MatrixObject matrixObject = (MatrixObject) dataObject;
matrixObject = PrivacyMonitor.handlePrivacy(matrixObject);
@@ -261,12 +268,7 @@
outNumCols += numMissing;
}
MatrixBlock ret = new MatrixBlock(outNumRows, outNumCols, operator.aggOp.initialValue);
- try {
- LibMatrixAgg.aggregateUnaryMatrix(matrixBlock, ret, operator);
- }
- catch (Exception e) {
- return new FederatedResponse(FederatedResponse.Type.ERROR, "FederatedWorkerHandler: " + e);
- }
+ LibMatrixAgg.aggregateUnaryMatrix(matrixBlock, ret, operator);
// result block without correction
ret.dropLastRowsOrColumns(operator.aggOp.correction);
return new FederatedResponse(FederatedResponse.Type.SUCCESS, ret);
@@ -284,8 +286,8 @@
dataObject = PrivacyMonitor.handlePrivacy(dataObject);
if (dataObject.getDataType() != Types.DataType.MATRIX) {
return new FederatedResponse(FederatedResponse.Type.ERROR,
- "FederatedWorkerHandler: ScalarOperator dont support "
- + dataObject.getDataType().name());
+ new FederatedWorkerHandlerException("FederatedWorkerHandler: ScalarOperator dont support "
+ + dataObject.getDataType().name()));
}
MatrixObject matrixObject = (MatrixObject) dataObject;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandlerException.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandlerException.java
new file mode 100644
index 0000000..79c1a6b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandlerException.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+/**
+ * Exception to throw when an exception occurs in FederatedWorkerHandler during handling of FederatedRequest. The
+ * purpose of FederatedWorkerHandlerException is to propagate useful information from the federated workers to the
+ * federated master without exposing details that are usually included in exceptions, for instance name of files that
+ * were not found or data points that could not be handled correctly.
+ */
+public class FederatedWorkerHandlerException extends RuntimeException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Create new instance of FederatedWorkerHandlerException with a message.
+ *
+ * @param msg message describing the exception
+ */
+ public FederatedWorkerHandlerException(String msg) {
+ super(msg);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/CheckedConstraintsLog.java b/src/main/java/org/apache/sysds/runtime/privacy/CheckedConstraintsLog.java
index e6bc7c0..9f11300 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/CheckedConstraintsLog.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/CheckedConstraintsLog.java
@@ -26,7 +26,15 @@
import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+/**
+ * Class counting the checked privacy constraints and the loaded privacy constraints.
+ */
public class CheckedConstraintsLog {
+ private static Map<PrivacyLevel,LongAdder> loadedConstraintsTotal = new EnumMap<PrivacyLevel,LongAdder>(PrivacyLevel.class);
+ static {
+ for ( PrivacyLevel level : PrivacyLevel.values() )
+ loadedConstraintsTotal.put(level, new LongAdder());
+ }
private static Map<PrivacyLevel,LongAdder> checkedConstraintsTotal = new EnumMap<PrivacyLevel,LongAdder>(PrivacyLevel.class);
private static BiFunction<LongAdder, LongAdder, LongAdder> mergeLongAdders = (v1, v2) -> {
v1.add(v2.longValue() );
@@ -45,23 +53,40 @@
}
/**
- * Remove all elements from checked constraints log.
+ * Add an occurence of the given privacy level to the loaded constraints log total.
+ * @param level privacy level from loaded privacy constraint
+ */
+ public static void addLoadedConstraint(PrivacyLevel level){
+ if (level != null)
+ loadedConstraintsTotal.get(level).increment();
+ }
+
+ /**
+ * Remove all elements from checked constraints log and loaded constraints log.
*/
public static void reset(){
checkedConstraintsTotal.clear();
+ loadedConstraintsTotal.replaceAll((k,v)->new LongAdder());
}
public static Map<PrivacyLevel,LongAdder> getCheckedConstraints(){
return checkedConstraintsTotal;
}
+ public static Map<PrivacyLevel, LongAdder> getLoadedConstraints(){
+ return loadedConstraintsTotal;
+ }
+
/**
* Get string representing all contents of the checked constraints log.
* @return string representation of checked constraints log.
*/
public static String display(){
StringBuilder sb = new StringBuilder();
+ sb.append("Checked Privacy Constraints:\n");
checkedConstraintsTotal.forEach((k,v)->sb.append("\t" + k + ": " + v + "\n"));
+ sb.append("Loaded Privacy Constraints:\n");
+ loadedConstraintsTotal.forEach((k,v)->sb.append("\t" + k + ": " + v + "\n"));
return sb.toString();
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
index 3978b6d..779f464 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.privacy;
import java.util.EnumMap;
-import java.util.HashMap;
import java.util.concurrent.atomic.LongAdder;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
index 298b1a7..c639441 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
@@ -365,4 +365,4 @@
// if privacy level is PrivateAggregation and data is scalar, the call should pass without propagating any constraints
}
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 7130004..cc813a4 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -999,7 +999,7 @@
}
if (DMLScript.CHECK_PRIVACY)
- sb.append("Checked Privacy Constraints:\n" + CheckedConstraintsLog.display());
+ sb.append(CheckedConstraintsLog.display());
return sb.toString();
}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/BuiltinGLMTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/BuiltinGLMTest.java
new file mode 100644
index 0000000..a2b2f29
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/BuiltinGLMTest.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Adapted from org.apache.sysds.test.functions.builtin.BuiltinGLMTest.
+ * Different privacy constraints are added to the input.
+ */
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class BuiltinGLMTest extends AutomatedTestBase
+{
+ protected final static String TEST_NAME = "glmTest";
+ protected final static String TEST_DIR = "functions/builtin/";
+ protected String TEST_CLASS_DIR = TEST_DIR + BuiltinGLMTest.class.getSimpleName() + "/";
+ double eps = 1e-4;
+
+ protected int numRecords, numFeatures, distFamilyType, linkType, intercept;
+ protected double distParam, linkPower, logFeatureVarianceDisbalance, avgLinearForm, stdevLinearForm, dispersion;
+
+ public BuiltinGLMTest(int numRecords_, int numFeatures_, int distFamilyType_, double distParam_,
+ int linkType_, double linkPower_, double logFeatureVarianceDisbalance_,
+ double avgLinearForm_, double stdevLinearForm_, double dispersion_)
+ {
+ this.numRecords = numRecords_;
+ this.numFeatures = numFeatures_;
+ this.distFamilyType = distFamilyType_;
+ this.distParam = distParam_;
+ this.linkType = linkType_;
+ this.linkPower = linkPower_;
+ this.logFeatureVarianceDisbalance = logFeatureVarianceDisbalance_;
+ this.avgLinearForm = avgLinearForm_;
+ this.stdevLinearForm = stdevLinearForm_;
+ this.dispersion = dispersion_;
+ }
+
+ private void setIntercept(int intercept_)
+ {
+ intercept = intercept_/100;
+ }
+
+ @Override
+ public void setUp()
+ {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+ }
+
+ // Private
+ @Test
+ public void glmTestIntercept_0_CP_Private() {
+ setIntercept(0);
+ runtestGLM(new PrivacyConstraint(PrivacyLevel.Private), DMLException.class);
+ }
+
+ // PrivateAggregation
+ @Test
+ public void glmTestIntercept_0_CP_PrivateAggregation() {
+ setIntercept(0);
+ runtestGLM(new PrivacyConstraint(PrivacyLevel.PrivateAggregation), null);
+ }
+
+ // None
+ @Test
+ public void glmTestIntercept_0_CP_None() {
+ setIntercept(0);
+ runtestGLM(new PrivacyConstraint(PrivacyLevel.None), null);
+ }
+
+ public void runtestGLM(PrivacyConstraint privacyConstraint, Class<?> expectedException) {
+ Types.ExecMode platformOld = setExecMode(LopProperties.ExecType.CP);
+ try {
+ int rows = numRecords; // # of rows in the training data
+ int cols = numFeatures; // # of features in the training data
+ System.out.println("------------ BEGIN " + TEST_NAME + " TEST WITH {" + rows + ", " + cols
+ + ", " + distFamilyType + ", " + distParam + ", " + linkType + ", " + linkPower + ", "
+ + intercept + ", " + logFeatureVarianceDisbalance + ", " + avgLinearForm + ", " + stdevLinearForm
+ + ", " + dispersion + "} ------------");
+
+ TestUtils.GLMDist glmdist = new TestUtils.GLMDist(distFamilyType, distParam, linkType, linkPower);
+ glmdist.set_dispersion(dispersion);
+
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ // prepare training data set
+ Random r = new Random(314159265);
+ double[][] X = TestUtils.generateUnbalancedGLMInputDataX(rows, cols, logFeatureVarianceDisbalance);
+ double[] beta = TestUtils.generateUnbalancedGLMInputDataB(X, cols, intercept, avgLinearForm, stdevLinearForm, r);
+ double[][] y = TestUtils.generateUnbalancedGLMInputDataY(X, beta, rows, cols, glmdist, intercept, dispersion, r);
+
+ int defaultBlockSize = OptimizerUtils.DEFAULT_BLOCKSIZE;
+
+ MatrixCharacteristics mc_X = new MatrixCharacteristics(rows, cols, defaultBlockSize, -1);
+ writeInputMatrixWithMTD("X", X, true, mc_X, privacyConstraint);
+
+ MatrixCharacteristics mc_y = new MatrixCharacteristics(rows, y[0].length, defaultBlockSize, -1);
+ writeInputMatrixWithMTD("Y", y, true, mc_y, privacyConstraint);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ List<String> proArgs = new ArrayList<>();
+ proArgs.add("-exec");
+ proArgs.add(" singlenode");
+ proArgs.add("-nvargs");
+ proArgs.add("X=" + input("X"));
+ proArgs.add("Y=" + input("Y"));
+ proArgs.add("dfam=" + String.valueOf(distFamilyType));
+ proArgs.add(((distFamilyType == 2 && distParam != 1.0) ? "yneg=" : "vpow=") + String.valueOf(distParam));
+ proArgs.add((distFamilyType == 2 && distParam != 1.0) ? "vpow=0.0" : "yneg=0.0");
+ proArgs.add("link=" + String.valueOf(linkType));
+ proArgs.add("lpow=" + String.valueOf(linkPower));
+ proArgs.add("icpt=" + String.valueOf(intercept)); // INTERCEPT - CHANGE THIS AS NEEDED
+ proArgs.add("disp=0.0"); // DISPERSION (0.0: ESTIMATE)
+ proArgs.add("reg=0.0"); // LAMBDA REGULARIZER
+ proArgs.add("tol=0.000000000001"); // TOLERANCE (EPSILON)
+ proArgs.add("moi=300");
+ proArgs.add("mii=0");
+ proArgs.add("B=" + output("betas_SYSTEMDS"));
+ programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = getRCmd(input("X.mtx"), input("Y.mtx"),
+ String.valueOf(distFamilyType),
+ String.valueOf(distParam),
+ String.valueOf(linkType),
+ String.valueOf(linkPower),
+ String.valueOf(intercept),
+ "0.000000000001",
+ expected("betas_R"));
+
+ runTest(true, (expectedException != null), expectedException, -1);
+
+ if ( expectedException == null ){
+
+ double max_abs_beta = 0.0;
+ HashMap<MatrixValue.CellIndex, Double> wTRUE = new HashMap<>();
+ for (int j = 0; j < cols; j++) {
+ wTRUE.put(new MatrixValue.CellIndex(j + 1, 1), Double.valueOf(beta[j]));
+ max_abs_beta = (max_abs_beta >= Math.abs(beta[j]) ? max_abs_beta : Math.abs(beta[j]));
+ }
+
+ HashMap<MatrixValue.CellIndex, Double> wSYSTEMDS_raw = readDMLMatrixFromHDFS("betas_SYSTEMDS");
+ HashMap<MatrixValue.CellIndex, Double> wSYSTEMDS = new HashMap<>();
+ for (MatrixValue.CellIndex key : wSYSTEMDS_raw.keySet())
+ if (key.column == 1)
+ wSYSTEMDS.put(key, wSYSTEMDS_raw.get(key));
+
+ runRScript(true);
+
+ HashMap<MatrixValue.CellIndex, Double> wR = readRMatrixFromFS("betas_R");
+
+ if ((distParam == 0 && linkType == 1)) { // Gaussian.*
+ //NOTE MB: Gaussian.log was the only test failing when we introduced multi-threaded
+ //matrix multplications (mmchain). After discussions with Sasha, we decided to change the eps
+ //because accuracy is anyway affected by various rewrites like binary to unary (-1*x->-x),
+ //transpose-matrixmult, and dot product sum. Disabling these rewrites led to a successful
+ //test result. Even without multi-threaded matrix mult this test was failing for different number
+ //of rows if these rewrites are enabled. Users can turn off rewrites if high accuracy is required.
+ //However, in the future we might also consider to use Kahan plus for aggregations in matrix mult
+ //(at least for the final aggregation of partial results from individual threads).
+
+ //NOTE MB: similar issues occurred with other tests when moving to github action tests
+ eps *= (linkPower == -1) ? 4 : 2; //Gaussian.inverse vs Gaussian.*;
+ }
+ TestUtils.compareMatrices(wR, wSYSTEMDS, eps * max_abs_beta, "wR", "wSYSTEMDS");
+ }
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // SCHEMA:
+ // #RECORDS, #FEATURES, DISTRIBUTION_FAMILY, VARIANCE_POWER or BERNOULLI_NO, LINK_TYPE, LINK_POWER,
+ // LOG_FEATURE_VARIANCE_DISBALANCE, AVG_LINEAR_FORM, ST_DEV_LINEAR_FORM, DISPERSION
+ Object[][] data = new Object[][] {
+ // #RECS #FTRS DFM VPOW LNK LPOW LFVD AVGLT STDLT DISP
+ // Both DML and R work and compute close results:
+ { 10000, 50, 1, 0.0, 1, 0.0, 3.0, 10.0, 2.0, 2.5 }, // Gaussian.log
+ { 1000, 100, 1, 1.0, 1, 0.0, 3.0, 0.0, 1.0, 2.5 }, // Poisson.log
+ { 10000, 50, 1, 2.0, 1, 0.0, 3.0, 0.0, 2.0, 2.5 }, // Gamma.log
+
+ { 10000, 50, 2, -1.0, 1, 0.0, 3.0, -5.0, 1.0, 1.0 }, // Bernoulli {-1, 1}.log // Note: Y is sparse
+ { 1000, 100, 2, -1.0, 2, 0.0, 3.0, 0.0, 2.0, 1.0 }, // Bernoulli {-1, 1}.logit
+ { 2000, 100, 2, -1.0, 3, 0.0, 3.0, 0.0, 2.0, 1.0 }, // Bernoulli {-1, 1}.probit
+
+ { 10000, 50, 2, 1.0, 1, 0.0, 3.0, -5.0, 1.0, 2.5 }, // Binomial two-column.log // Note: Y is sparse
+ { 1000, 100, 2, 1.0, 2, 0.0, 3.0, 0.0, 2.0, 2.5 }, // Binomial two-column.logit
+ { 2000, 100, 2, 1.0, 3, 0.0, 3.0, 0.0, 2.0, 2.5 }, // Binomial two-column.probit
+ };
+ return Arrays.asList(data);
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/CheckedConstraintsLogTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/CheckedConstraintsLogTest.java
index d187436..ae0a5f9 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/CheckedConstraintsLogTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/CheckedConstraintsLogTest.java
@@ -19,6 +19,9 @@
package org.apache.sysds.test.functions.privacy;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
import java.util.EnumMap;
import java.util.concurrent.atomic.LongAdder;
@@ -32,27 +35,27 @@
@Override
public void setUp() {
- CheckedConstraintsLog.getCheckedConstraints().clear();
+ CheckedConstraintsLog.reset();
}
@Test
public void addCheckedConstraintsNull(){
CheckedConstraintsLog.addCheckedConstraints(null);
- assert(CheckedConstraintsLog.getCheckedConstraints() != null && CheckedConstraintsLog.getCheckedConstraints().isEmpty());
+ assertTrue(CheckedConstraintsLog.getCheckedConstraints() != null && CheckedConstraintsLog.getCheckedConstraints().isEmpty());
}
@Test
public void addCheckedConstraintsEmpty(){
EnumMap<PrivacyLevel,LongAdder> checked = new EnumMap<>(PrivacyLevel.class);
CheckedConstraintsLog.addCheckedConstraints(checked);
- assert(CheckedConstraintsLog.getCheckedConstraints() != null && CheckedConstraintsLog.getCheckedConstraints().isEmpty());
+ assertTrue(CheckedConstraintsLog.getCheckedConstraints() != null && CheckedConstraintsLog.getCheckedConstraints().isEmpty());
}
@Test
public void addCheckedConstraintsSingleValue(){
EnumMap<PrivacyLevel,LongAdder> checked = getMap(PrivacyLevel.Private, 300);
CheckedConstraintsLog.addCheckedConstraints(checked);
- assert(CheckedConstraintsLog.getCheckedConstraints().get(PrivacyLevel.Private).longValue() == 300);
+ assertTrue(CheckedConstraintsLog.getCheckedConstraints().get(PrivacyLevel.Private).longValue() == 300);
}
@Test
@@ -61,7 +64,7 @@
CheckedConstraintsLog.addCheckedConstraints(checked);
EnumMap<PrivacyLevel,LongAdder> checked2 = getMap(PrivacyLevel.Private, 150);
CheckedConstraintsLog.addCheckedConstraints(checked2);
- assert(CheckedConstraintsLog.getCheckedConstraints().get(PrivacyLevel.Private).longValue() == 450);
+ assertTrue(CheckedConstraintsLog.getCheckedConstraints().get(PrivacyLevel.Private).longValue() == 450);
}
@Test
@@ -72,7 +75,7 @@
CheckedConstraintsLog.addCheckedConstraints(checked2);
EnumMap<PrivacyLevel,LongAdder> checked3 = getMap(PrivacyLevel.PrivateAggregation, 150);
CheckedConstraintsLog.addCheckedConstraints(checked3);
- assert(CheckedConstraintsLog.getCheckedConstraints().get(PrivacyLevel.Private).longValue() == 450
+ assertTrue(CheckedConstraintsLog.getCheckedConstraints().get(PrivacyLevel.Private).longValue() == 450
&& CheckedConstraintsLog.getCheckedConstraints().get(PrivacyLevel.PrivateAggregation).longValue() == 150);
}
@@ -83,4 +86,12 @@
checked.put(level, valueAdder);
return checked;
}
-}
\ No newline at end of file
+
+ @Test
+ public void addLoadedConstraintsSingleValue(){
+ Integer n = 12;
+ for (int i = 0; i < n; i++)
+ CheckedConstraintsLog.addLoadedConstraint(PrivacyLevel.Private);
+ assertEquals(n.longValue(), CheckedConstraintsLog.getLoadedConstraints().get(PrivacyLevel.Private).longValue());
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
index 129b3b8..8698448 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
@@ -345,4 +345,4 @@
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
-}
\ No newline at end of file
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/GLMTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/GLMTest.java
index 271a711..69fc2dc 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/GLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/GLMTest.java
@@ -39,6 +39,11 @@
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+/**
+ * Adapted from org.apache.sysds.test.applications.GLMTest.
+ * Different privacy constraints are added to the input.
+ */
+
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class GLMTest extends AutomatedTestBase
@@ -183,46 +188,8 @@
@Test
public void TestGLMPrivateX(){
-
PrivacyConstraint pc = new PrivacyConstraint(PrivacyLevel.Private);
- Class<?> expectedException = null;
- switch ( glmType ){
- case Gaussianinverse:
- case Poissonlog1:
- case Poissonlog2:
- case Poissonsqrt:
- case Poissonid:
- case Gammalog:
- case Gammainverse:
- case InvGaussian1mu:
- case InvGaussianinverse:
- case InvGaussianlog:
- case InvGaussianid:
- case Binomialid:
- case Binomialcauchit:
- case Gaussianlog:
- case Gaussianid:
- case Bernoullilog:
- case Bernoulliid:
- case Bernoullisqrt:
- case Bernoullilogit1:
- case Bernoullilogit2:
- case Bernoulliprobit1:
- case Bernoulliprobit2:
- case Bernoullicloglog1:
- case Bernoullicloglog2:
- case Bernoullicauchit:
- case Binomiallog:
- case Binomialsqrt:
- case Binomiallogit:
- case Binomialprobit:
- case Binomialcloglog:
- expectedException = DMLException.class;
- break;
- default:
- expectedException = null;
- break;
- }
+ Class<?> expectedException = DMLException.class;
testGLM(pc, null, expectedException);
}
@@ -243,44 +210,7 @@
@Test
public void TestGLMPrivateY(){
PrivacyConstraint pc = new PrivacyConstraint(PrivacyLevel.Private);
- Class<?> expectedException = null;
- switch ( glmType ){
- case Gaussianinverse:
- case Poissonlog1:
- case Poissonlog2:
- case Poissonsqrt:
- case Poissonid:
- case Gammalog:
- case Gammainverse:
- case InvGaussian1mu:
- case InvGaussianinverse:
- case InvGaussianlog:
- case InvGaussianid:
- case Binomialid:
- case Binomialcauchit:
- case Gaussianlog:
- case Gaussianid:
- case Bernoullilog:
- case Bernoulliid:
- case Bernoullisqrt:
- case Bernoullilogit1:
- case Bernoullilogit2:
- case Bernoulliprobit1:
- case Bernoulliprobit2:
- case Bernoullicloglog1:
- case Bernoullicloglog2:
- case Bernoullicauchit:
- case Binomiallog:
- case Binomialsqrt:
- case Binomiallogit:
- case Binomialprobit:
- case Binomialcloglog:
- expectedException = DMLException.class;
- break;
- default:
- expectedException = null;
- break;
- }
+ Class<?> expectedException = DMLException.class;
testGLM(null, pc, expectedException);
}