[SYSTEMDS-2627] Federated mmchain instruction for lmCG, MLogreg, GLM
This patch adds a federated mmchain instruction for a common
matrix-vector multiplication chain as it appears in the inner loop of
lmCG, Mlogreg, and GLM. It also includes a fix for more robust
instruction manipulation, and a GLM federated test.
Furthermore, we now use a slightly better approach for deciding between
conf-only and context spark cluster analysis to avoid unnecessary spark
context creation in local mode (which sometimes interferes with netty
port allocation in federated tests).
diff --git a/src/main/java/org/apache/sysds/lops/MapMultChain.java b/src/main/java/org/apache/sysds/lops/MapMultChain.java
index 79d57f7..b45d813 100644
--- a/src/main/java/org/apache/sysds/lops/MapMultChain.java
+++ b/src/main/java/org/apache/sysds/lops/MapMultChain.java
@@ -35,7 +35,10 @@
XtXv, //(t(X) %*% (X %*% v))
XtwXv, //(t(X) %*% (w * (X %*% v)))
XtXvy, //(t(X) %*% ((X %*% v) - y))
- NONE,
+ NONE;
+ public boolean isWeighted() {
+ return this == XtwXv || this == ChainType.XtXvy;
+ }
}
private ChainType _chainType = null;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 65348f1..2be647d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1771,6 +1771,12 @@
_defaultPar = (defaultPar>1) ? defaultPar : numExecutors * numCoresPerExec;
_confOnly &= true;
}
+ else if( DMLScript.USE_LOCAL_SPARK_CONFIG ) {
+ //avoid unnecessary spark context creation in local mode (e.g., tests)
+ _numExecutors = 1;
+ _defaultPar = 2;
+ _confOnly &= true;
+ }
else {
//get default parallelism (total number of executors and cores)
//note: spark context provides this information while conf does not
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index c34fa62..429834b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -52,10 +52,14 @@
//TODO better and safe replacement of operand names --> instruction utils
long id = getNextFedDataID();
String linst = inst.replace(ExecType.SPARK.name(), ExecType.CP.name());
- linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName(), Lop.OPERAND_DELIMITOR+String.valueOf(id));
+ linst = linst.replace(
+ Lop.OPERAND_DELIMITOR+varOldOut.getName()+Lop.DATATYPE_PREFIX,
+ Lop.OPERAND_DELIMITOR+String.valueOf(id)+Lop.DATATYPE_PREFIX);
for(int i=0; i<varOldIn.length; i++)
if( varOldIn[i] != null ) {
- linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(), Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+ linst = linst.replace(
+ Lop.OPERAND_DELIMITOR+varOldIn[i].getName()+Lop.DATATYPE_PREFIX,
+ Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i])+Lop.DATATYPE_PREFIX);
linst = linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i])); //parameterized
}
return new FederatedRequest(RequestType.EXEC_INST, id, linst);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
index f540343..dcff65b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
@@ -36,31 +36,30 @@
_type = type;
_numThreads = k;
}
+
+ public ChainType getMMChainType() {
+ return _type;
+ }
public static MMChainCPInstruction parseInstruction ( String str ) {
//parse instruction parts (without exec type)
String[] parts = InstructionUtils.getInstructionPartsWithValueType( str );
InstructionUtils.checkNumFields( parts, 5, 6 );
-
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
- if( parts.length==6 )
- {
+ if( parts.length==6 ) {
CPOperand out= new CPOperand(parts[3]);
ChainType type = ChainType.valueOf(parts[4]);
int k = Integer.parseInt(parts[5]);
-
return new MMChainCPInstruction(null, in1, in2, null, out, type, k, opcode, str);
}
- else //parts.length==7
- {
+ else { //parts.length==7
CPOperand in3 = new CPOperand(parts[3]);
CPOperand out = new CPOperand(parts[4]);
ChainType type = ChainType.valueOf(parts[5]);
int k = Integer.parseInt(parts[6]);
-
return new MMChainCPInstruction(null, in1, in2, in3, out, type, k, opcode, str);
}
}
@@ -70,19 +69,15 @@
//get inputs
MatrixBlock X = ec.getMatrixInput(input1.getName());
MatrixBlock v = ec.getMatrixInput(input2.getName());
- MatrixBlock w = (_type==ChainType.XtwXv || _type==ChainType.XtXvy) ?
- ec.getMatrixInput(input3.getName()) : null;
+ MatrixBlock w = _type.isWeighted() ? ec.getMatrixInput(input3.getName()) : null;
+
//execute mmchain operation
- MatrixBlock out = X.chainMatrixMultOperations(v, w, new MatrixBlock(), _type, _numThreads);
+ MatrixBlock out = X.chainMatrixMultOperations(v, w, new MatrixBlock(), _type, _numThreads);
+
//set output and release inputs
ec.setMatrixOutput(output.getName(), out);
ec.releaseMatrixInput(input1.getName(), input2.getName());
if( w !=null )
ec.releaseMatrixInput(input3.getName());
}
-
- public ChainType getMMChainType()
- {
- return _type;
- }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 6df1b1e..77dedfd 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -35,6 +35,7 @@
MultiReturnParameterizedBuiltin,
ParameterizedBuiltin,
Tsmm,
+ MMChain,
}
protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 4325456..bbdaa8e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -45,6 +45,18 @@
}
}
}
+ else if( inst instanceof MMChainCPInstruction) {
+ MMChainCPInstruction linst = (MMChainCPInstruction) inst;
+ MatrixObject mo = ec.getMatrixObject(linst.input1);
+ if( mo.isFederated() )
+ fedinst = MMChainFEDInstruction.parseInstruction(linst.getInstructionString());
+ }
+ else if( inst instanceof MMTSJCPInstruction ) {
+ MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
+ MatrixObject mo = ec.getMatrixObject(linst.input1);
+ if( mo.isFederated() )
+ fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
+ }
else if (inst instanceof AggregateUnaryCPInstruction) {
AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst;
if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) {
@@ -77,12 +89,6 @@
}
}
}
- else if( inst instanceof MMTSJCPInstruction ) {
- MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
- MatrixObject mo = ec.getMatrixObject(linst.input1);
- if( mo.isFederated() )
- fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
- }
//set thread id for federated context management
if( fedinst != null ) {
@@ -94,13 +100,14 @@
}
public static Instruction checkAndReplaceSP(Instruction inst, ExecutionContext ec) {
+ FEDInstruction fedinst = null;
if (inst instanceof MapmmSPInstruction) {
// FIXME does not yet work for MV multiplication. SPARK execution mode not supported for federated l2svm
MapmmSPInstruction instruction = (MapmmSPInstruction) inst;
Data data = ec.getVariable(instruction.input1);
if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
// TODO correct FED instruction string
- return new AggregateBinaryFEDInstruction(instruction.getOperator(),
+ fedinst = new AggregateBinaryFEDInstruction(instruction.getOperator(),
instruction.input1, instruction.input2, instruction.output, "ba+*", "FED...");
}
}
@@ -108,7 +115,7 @@
AggregateUnarySPInstruction instruction = (AggregateUnarySPInstruction) inst;
Data data = ec.getVariable(instruction.input1);
if (data instanceof MatrixObject && ((MatrixObject) data).isFederated())
- return AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
else if (inst instanceof WriteSPInstruction) {
WriteSPInstruction instruction = (WriteSPInstruction) inst;
@@ -124,9 +131,15 @@
AppendGAlignedSPInstruction instruction = (AppendGAlignedSPInstruction) inst;
Data data = ec.getVariable(instruction.input1);
if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
- return AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+ fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
}
}
+ //set thread id for federated context management
+ if( fedinst != null ) {
+ fedinst.setTID(ec.getTID());
+ return fedinst;
+ }
+
return inst;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
new file mode 100644
index 0000000..2dee64b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.lops.MapMultChain.ChainType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.concurrent.Future;
+
+public class MMChainFEDInstruction extends UnaryFEDInstruction {
+
+ public MMChainFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3,
+ CPOperand out, ChainType type, int k, String opcode, String istr) {
+ super(FEDType.MMChain, null, in1, in2, in3, out, opcode, istr);
+ _type = type;
+ }
+
+ private final ChainType _type;
+
+ public ChainType getMMChainType() {
+ return _type;
+ }
+
+ public static MMChainFEDInstruction parseInstruction ( String str ) {
+ //parse instruction parts (without exec type)
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType( str );
+ InstructionUtils.checkNumFields( parts, 5, 6 );
+ String opcode = parts[0];
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+
+ if( parts.length==6 ) {
+ CPOperand out= new CPOperand(parts[3]);
+ ChainType type = ChainType.valueOf(parts[4]);
+ int k = Integer.parseInt(parts[5]);
+ return new MMChainFEDInstruction(in1, in2, null, out, type, k, opcode, str);
+ }
+ else { //parts.length==7
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand out = new CPOperand(parts[4]);
+ ChainType type = ChainType.valueOf(parts[5]);
+ int k = Integer.parseInt(parts[6]);
+ return new MMChainFEDInstruction(in1, in2, in3, out, type, k, opcode, str);
+ }
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+ MatrixObject mo2 = ec.getMatrixObject(input2);
+ MatrixObject mo3 = _type.isWeighted() ? ec.getMatrixObject(input3) : null;
+
+ if( !mo1.isFederated() )
+ throw new DMLRuntimeException("Federated MMChain: Federated main input expected, "
+ + "but invoked w/ "+mo1.isFederated()+" "+mo2.isFederated());
+
+ if( !_type.isWeighted() ) { //XtXv
+ //construct commands: broadcast vector, execute, get and aggregate, cleanup
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ else { //XtwXv | XtXvy
+ //construct commands: broadcast 2 vectors, execute, get and aggregate, cleanup
+ FederatedRequest[] fr0 = mo1.getFedMapping().broadcastSliced(mo3, false);
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2, input3},
+ new long[]{mo1.getFedMapping().getID(), fr1.getID(), fr0[0].getID()});
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ mo1.getFedMapping().cleanup(getTID(), fr0[0].getID(), fr1.getID(), fr2.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java
new file mode 100644
index 0000000..fe24bc8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedGLMTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedGLMTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedGLMTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows have to be even and > 1
+ return Arrays.asList(new Object[][] {{10000, 10}, {1000, 100}, {2000, 43}});
+ }
+
+ @Test
+ public void federatedSinglenodeGLM() {
+ federatedGLM(Types.ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void federatedHybridGLM() {
+ federatedGLM(Types.ExecMode.HYBRID);
+ }
+
+
+ public void federatedGLM(Types.ExecMode execMode) {
+ ExecMode platformOld = setExecMode(execMode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233);
+ for(int i = 0; i < rows; i++)
+ Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
+
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows));
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorker(port1);
+ Thread t2 = startLocalFedWorker(port2);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+ setOutputBuffering(false);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats",
+ "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+ "in_Y=" + input("Y"), "out=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+
+ TestUtils.shutdownThreads(t1, t2);
+
+ // check for federated operations
+ Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
+
+ //check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
+ resetExecMode(platformOld);
+ }
+}
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml b/src/test/scripts/functions/federated/FederatedGLMTest.dml
new file mode 100644
index 0000000..aa23b5e
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedGLMTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+Y = read($in_Y)
+
+model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
+write(model, $out)
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
new file mode 100644
index 0000000..a307c8c
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+Y = read($3)
+model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
+write(model, $4)