/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.twister2;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.tset.TBase;
import edu.iu.dsc.tws.api.tset.sets.TSet;
import edu.iu.dsc.tws.api.tset.sets.batch.BatchTSet;
import edu.iu.dsc.tws.tset.TBaseGraph;
import edu.iu.dsc.tws.tset.env.BatchTSetEnvironment;
import edu.iu.dsc.tws.tset.links.BaseTLink;
import edu.iu.dsc.tws.tset.sets.BaseTSet;
import edu.iu.dsc.tws.tset.sets.BuildableTSet;
import edu.iu.dsc.tws.tset.sets.batch.CachedTSet;
import edu.iu.dsc.tws.tset.sets.batch.ComputeTSet;
import edu.iu.dsc.tws.tset.sets.batch.SinkTSet;
import edu.iu.dsc.tws.tset.worker.BatchTSetIWorker;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.twister2.translators.functions.DoFnFunction;
import org.apache.beam.runners.twister2.translators.functions.Twister2SinkFunction;
/**
* The Twister2 worker that will execute the job logic once the job is submitted from the run
* method.
*/
@SuppressWarnings({
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class BeamBatchWorker implements Serializable, BatchTSetIWorker {
private static final String SIDEINPUTS = "sideInputs";
private static final String LEAVES = "leaves";
private static final String GRAPH = "graph";
private HashMap<String, BatchTSet<?>> sideInputDataSets;
private Set<TSet> leaves;
@Override
public void execute(BatchTSetEnvironment env) {
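    // The serialized pipeline graph, the side input ids, and the leaf ids are shipped to the
    // worker through the job Config.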
Config config = env.getConfig();
Map<String, String> sideInputIds = (LinkedHashMap<String, String>) config.get(SIDEINPUTS);
Set<String> leaveIds = (Set<String>) config.get(LEAVES);
TBaseGraph graph = (TBaseGraph) config.get(GRAPH);
env.settBaseGraph(graph);
setupTSets(env, sideInputIds, leaveIds);
resetEnv(env, graph);
executePipeline(env);
}
  /**
   * Resets the proper TSetEnvironment on all the TSets and TLinks in the graph that was sent
   * over the wire.
   *
   * @param env the TSetEnvironment to be set
   * @param graph the graph that contains all the TSets
   */
private void resetEnv(BatchTSetEnvironment env, TBaseGraph graph) {
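    // The graph was sent over the wire, so every TSet and TLink must have the runtime
    // environment re-attached before the pipeline can be executed.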
Set<TBase> nodes = graph.getNodes();
for (TBase node : nodes) {
if (node instanceof BaseTSet) {
((BaseTSet) node).setTSetEnv(env);
} else if (node instanceof BaseTLink) {
((BaseTLink) node).setTSetEnv(env);
} else {
throw new IllegalStateException("node must be either of type BaseTSet or BaseTLink");
}
}
}
/**
   * Extracts the side input TSets and the leaf TSets from the graph.
   *
   * @param env BatchTSetEnvironment which contains the TSet graph
   * @param sideInputIds the ids of the side inputs
   * @param leaveIds the ids of the leaf TSets
*/
private void setupTSets(
BatchTSetEnvironment env, Map<String, String> sideInputIds, Set<String> leaveIds) {
sideInputDataSets = new LinkedHashMap<>();
leaves = new HashSet<>();
// reset sources, so that the graph does not have two source objects
// created during deserialization
Set<BuildableTSet> newSources = new HashSet<>();
for (BuildableTSet source : env.getGraph().getSources()) {
newSources.add((BuildableTSet) env.getGraph().getNodeById(source.getId()));
}
env.getGraph().setSources(newSources);
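    // Map each side input key to the corresponding TSet node in the deserialized graph.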
for (Map.Entry<String, String> entry : sideInputIds.entrySet()) {
BatchTSet curr = (BatchTSet) env.getGraph().getNodeById(entry.getValue());
sideInputDataSets.put(entry.getKey(), curr);
}
for (String leaveId : leaveIds) {
leaves.add((TSet) env.getGraph().getNodeById(leaveId));
}
}
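  /**
   * Executes the pipeline: caches the side input TSets first, then evaluates every leaf TSet by
   * attaching a sink to it.
   */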
public void executePipeline(BatchTSetEnvironment env) {
Map<String, CachedTSet> sideInputTSets = new HashMap<>();
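    // Cache each side input TSet; the cached results are later wired into the DoFnFunctions
    // that consume them (see addInputs).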
for (Map.Entry<String, BatchTSet<?>> sides : sideInputDataSets.entrySet()) {
BatchTSet<?> sideTSet = sides.getValue();
addInputs((BaseTSet) sideTSet, sideInputTSets);
CachedTSet tempCache = (CachedTSet) sideTSet.cache();
sideInputTSets.put(sides.getKey(), tempCache);
}
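    // Attach a sink to every leaf TSet, wire the cached side inputs into its subgraph, and run it.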
for (TSet leaf : leaves) {
SinkTSet sinkTSet = (SinkTSet) leaf.direct().sink(new Twister2SinkFunction());
addInputs(sinkTSet, sideInputTSets);
eval(env, sinkTSet);
}
}
  /** Adds all the side inputs to the sink TSet so they are available from the DoFns. */
private void addInputs(BaseTSet sinkTSet, Map<String, CachedTSet> sideInputTSets) {
if (sideInputTSets.isEmpty()) {
return;
}
TBaseGraph graph = sinkTSet.getTBaseGraph();
TBase currNode = null;
Deque<TBase> deque = new ArrayDeque<>();
deque.add(sinkTSet);
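    // Walk the graph backwards from the sink through its predecessors and register the cached
    // side inputs on every ComputeTSet whose compute function is a DoFnFunction.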
while (!deque.isEmpty()) {
currNode = deque.remove();
deque.addAll(graph.getPredecessors(currNode));
if (currNode instanceof ComputeTSet) {
if (((ComputeTSet) currNode).getComputeFunc() instanceof DoFnFunction) {
Set<String> sideInputKeys =
((DoFnFunction) ((ComputeTSet) currNode).getComputeFunc()).getSideInputKeys();
for (String sideInputKey : sideInputKeys) {
if (!sideInputTSets.containsKey(sideInputKey)) {
throw new IllegalStateException("Side input not found for key " + sideInputKey);
}
((ComputeTSet) currNode).addInput(sideInputKey, sideInputTSets.get(sideInputKey));
}
}
}
}
}
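  /** Submits the given sink TSet to the environment for execution. */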
public void eval(BatchTSetEnvironment env, SinkTSet<?> tSet) {
env.run(tSet);
}
}