blob: 8d98436e257643d0cfb9201550f9ef00dd67ee80 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.functions.jmlc;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.api.jmlc.Connection;
import org.apache.sysds.api.jmlc.PreparedScript;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.utils.Statistics;
public class JMLCClonedPreparedScriptTest extends AutomatedTestBase
{
//basic script with parfor loop
private static final String SCRIPT1 =
"X = matrix(7, 10, 10);"
+ "R = matrix(0, 10, 1)"
+ "parfor(i in 1:nrow(X))"
+ " R[i,] = sum(X[i,])"
+ "out = sum(R)"
+ "write(out, 'tmp/out')";
//script with dml-bodied and external functions
private static final String SCRIPT2 =
"foo = function(Matrix[double] A, Matrix[double] B, Matrix[double] C)"
+ " return (Matrix[double] D) {"
+ " while(FALSE){}"
+ " D = cbind(A, B, C)"
+ "}"
+ "X = matrix(7, 10, 10);"
+ "R = matrix(0, 10, 1)"
+ "for(i in 1:nrow(X)) {"
+ " E = foo(X[1,], X[2,], X[3,])"
+ " R[i,] = sum(E)/3"
+ "}"
+ "out = sum(R)"
+ "write(out, 'tmp/out')";
@Override
public void setUp() {
//do nothing
}
@Test
public void testSinglePreparedScript1T128() {
runJMLCClonedTest(SCRIPT1, 128, false);
}
@Test
public void testClonedPreparedScript1T128() {
runJMLCClonedTest(SCRIPT1, 128, true);
}
@Test
public void testSinglePreparedScript2T128() {
runJMLCClonedTest(SCRIPT2, 128, false);
}
@Test
public void testClonedPreparedScript2T128() {
runJMLCClonedTest(SCRIPT2, 128, true);
}
private static void runJMLCClonedTest(String script, int num, boolean clone) {
int k = InfrastructureAnalyzer.getLocalParallelism();
boolean failed = false;
try( Connection conn = new Connection() ) {
DMLScript.STATISTICS = true;
Statistics.reset();
PreparedScript pscript = conn.prepareScript(
script, new String[]{}, new String[]{"out"});
ExecutorService pool = Executors.newFixedThreadPool(k);
ArrayList<JMLCTask> tasks = new ArrayList<>();
for(int i=0; i<num; i++)
tasks.add(new JMLCTask(pscript, clone));
List<Future<Double>> taskrets = pool.invokeAll(tasks);
for(Future<Double> ret : taskrets)
if( ret.get() != 700 )
throw new RuntimeException("wrong results: "+ret.get());
pool.shutdown();
}
catch(Exception ex) {
ex.printStackTrace();
failed = true;
}
//check expected failure
Assert.assertTrue(failed==!clone || k==1);
}
private static class JMLCTask implements Callable<Double>
{
private final PreparedScript _pscript;
private final boolean _clone;
protected JMLCTask(PreparedScript pscript, boolean clone) {
_pscript = pscript;
_clone = clone;
}
@Override
public Double call() throws DMLException
{
if( _clone )
return _pscript.clone(false).executeScript().getDouble("out");
else
return _pscript.executeScript().getDouble("out");
}
}
}