/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.controlprogram.parfor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.parser.ParForStatementBlock.ResultVar;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.utils.Statistics;
import scala.Tuple2;

/**
 * This class serves two purposes: (1) isolating Spark imports to enable running in
 * environments where no Spark libraries are available, and (2) following the same
 * structure as the parfor remote_mr job submission.
 *
 * NOTE: currently, we still exchange inputs and outputs via HDFS (this covers the
 * general cases of data already residing in HDFS, in-memory data, and partitioned
 * inputs; it also allows for pre-aggregation by overwriting partial task results
 * with pre-aggregated results from subsequent iterations)
 *
 * TODO reducebykey on variable names
 */
public class RemoteParForSpark
{
protected static final Log LOG = LogFactory.getLog(RemoteParForSpark.class.getName());

//globally unique id for parfor spark job instances (unique across spark contexts)
private static final IDSequence _jobID = new IDSequence();

public static RemoteParForJobReturn runJob(long pfid, String prog, HashMap<String, byte[]> clsMap, List<Task> tasks,
ExecutionContext ec, ArrayList<ResultVar> resultVars, boolean cpCaching, int numMappers, boolean topLevelPF)
{
String jobname = "ParFor-ESP";
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
SparkExecutionContext sec = (SparkExecutionContext)ec;
JavaSparkContext sc = sec.getSparkContext();

//initialize accumulators for tasks/iterations
LongAccumulator aTasks = sc.sc().longAccumulator("tasks");
LongAccumulator aIters = sc.sc().longAccumulator("iterations");
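//(workers increment these accumulators remotely; their values become
// visible on the driver once the job below has been executed)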

//reset cached shared inputs for correctness in local mode
long jobid = _jobID.getNextID();
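//(in local mode, workers run inside the driver JVM, so statically cached
// inputs of previous parfor jobs must be cleared to avoid stale data)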
if( InfrastructureAnalyzer.isLocalMode() )
RemoteParForSparkWorker.cleanupCachedVariables(jobid);

// broadcast the inputs except the result variables
Map<String, Broadcast<CacheBlock>> brInputs = null;
if (ParForProgramBlock.ALLOW_BROADCAST_INPUTS) {
brInputs = broadcastInputs(sec, resultVars);
}

//prepare lineage
Map<String, String> serialLineage = DMLScript.LINEAGE ? ec.getLineage().serialize() : null;
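//(the serialized lineage map is shipped to the workers, which deserialize
// it to continue lineage tracing for their iterations)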

//run remote_spark parfor job
//(w/o lazy evaluation to fit existing parfor framework, e.g., result merge)
List<Tuple2<Long, String>> out = sc.parallelize(tasks, tasks.size()) //create rdd of parfor tasks
.flatMapToPair(new RemoteParForSparkWorker(jobid, prog,
clsMap, cpCaching, aTasks, aIters, brInputs, topLevelPF, serialLineage))
.collect(); //execute and get output handles

//de-serialize results
LocalVariableMap[] results = RemoteParForUtils.getResults(out, LOG);
Lineage[] lineages = DMLScript.LINEAGE ?
RemoteParForUtils.getLineages(results) : null;
int numTasks = aTasks.value().intValue(); //get accumulator value
int numIters = aIters.value().intValue(); //get accumulator value

//create output symbol table entries
RemoteParForJobReturn ret = new RemoteParForJobReturn(true, numTasks, numIters, results, lineages);

//maintain statistics
Statistics.incrementNoOfCompiledSPInst();
Statistics.incrementNoOfExecutedSPInst();
if( DMLScript.STATISTICS )
Statistics.maintainCPHeavyHitters(jobname, System.nanoTime()-t0);
return ret;
}
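
/**
 * Creates broadcast handles for all broadcast-eligible inputs of the given
 * execution context, i.e., all input variables except the parfor result
 * variables (which are created by the workers), scalars, and partitioned
 * matrices (see the inline comments below).
 *
 * @param sec spark execution context providing the symbol table and broadcast facility
 * @param resultVars parfor result variables to exclude from broadcasting
 * @return map of variable names to broadcast handles
 */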
@SuppressWarnings("unchecked")
private static Map<String, Broadcast<CacheBlock>> broadcastInputs(SparkExecutionContext sec, ArrayList<ResultVar> resultVars) {
LocalVariableMap inputs = sec.getVariables();
// exclude the result variables
// TODO use optimizer-picked list of amenable objects (e.g., size constraints)
Set<String> retVars = resultVars.stream()
.map(v -> v._name).collect(Collectors.toSet());
Set<String> brVars = inputs.keySet().stream()
.filter(v -> !retVars.contains(v)).collect(Collectors.toSet());
// construct broadcast objects
Map<String, Broadcast<CacheBlock>> result = new HashMap<>();
for (String key : brVars) {
Data var = sec.getVariable(key);
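//skip scalars (they travel with the serialized program) and partitioned
//matrices (workers read the required partitions directly)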
if ((var instanceof ScalarObject) || (var instanceof MatrixObject && ((MatrixObject) var).isPartitioned()))
continue;
result.put(key, sec.broadcastVariable((CacheableData<CacheBlock>) var));
}
return result;
}
}
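
/* Hypothetical usage sketch (illustration only; names other than runJob's own
 * parameters are assumptions): this entry point is invoked by the parfor
 * runtime (ParForProgramBlock) when executing in REMOTE_SPARK mode, roughly:
 *
 *   RemoteParForJobReturn ret = RemoteParForSpark.runJob(
 *     pfid, prog, clsMap, tasks, ec, resultVars,
 *     true, tasks.size(), true); //cpCaching, numMappers, topLevelPF
 */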