blob: 8bd0e482e9edd53d354f1af943eb5ee9e6683b91 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.graph;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.FileSourceFunction;
import org.apache.flink.streaming.api.transformations.CoFeedbackTransformation;
import org.apache.flink.streaming.api.transformations.FeedbackTransformation;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.SelectTransformation;
import org.apache.flink.streaming.api.transformations.SinkTransformation;
import org.apache.flink.streaming.api.transformations.SourceTransformation;
import org.apache.flink.streaming.api.transformations.SplitTransformation;
import org.apache.flink.streaming.api.transformations.StreamTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.UnionTransformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* A generator that generates a {@link StreamGraph} from a graph of
* {@link StreamTransformation StreamTransformations}.
*
* <p>
* This traverses the tree of {@code StreamTransformations} starting from the sinks. At each
* transformation we recursively transform the inputs, then create a node in the {@code StreamGraph}
* and add edges from the input Nodes to our newly created node. The transformation methods
* return the IDs of the nodes in the StreamGraph that represent the input transformation. Several
* IDs can be returned to be able to deal with feedback transformations and unions.
*
* <p>
* Partitioning, split/select and union don't create actual nodes in the {@code StreamGraph}. For
* these, we create a virtual node in the {@code StreamGraph} that holds the specific property, i.e.
 * partitioning, selector and so on. When an edge is created from a virtual node to a downstream
 * node the {@code StreamGraph} resolves the id of the original node and creates an edge
* in the graph with the desired property. For example, if you have this graph:
*
* <pre>
* Map-1 -&gt; HashPartition-2 -&gt; Map-3
* </pre>
*
* where the numbers represent transformation IDs. We first recurse all the way down. {@code Map-1}
* is transformed, i.e. we create a {@code StreamNode} with ID 1. Then we transform the
 * {@code HashPartition}, for this, we create a virtual node of ID 4 that holds the property
 * {@code HashPartition}. This transformation returns the ID 4. Then we transform the {@code Map-3}.
 * We add the edge {@code 4 -> 3}. The {@code StreamGraph} resolves the actual node with ID 1 and
 * creates an edge {@code 1 -> 3} with the property HashPartition.
*/
public class StreamGraphGenerator {

	private static final Logger LOG = LoggerFactory.getLogger(StreamGraphGenerator.class);

	// The StreamGraph that is being built, this is initialized at the beginning.
	private final StreamGraph streamGraph;

	private final StreamExecutionEnvironment env;

	// This is used to assign a unique ID to iteration source/sink pairs. The counter
	// goes downwards from 0 so iteration node IDs can never collide with the positive
	// IDs assigned to regular transformation nodes.
	protected static Integer iterationIdCounter = 0;

	public static int getNewIterationNodeId() {
		iterationIdCounter--;
		return iterationIdCounter;
	}

	// Keep track of which Transforms we have already transformed, this is necessary because
	// we have loops, i.e. feedback edges.
	private final Map<StreamTransformation<?>, Collection<Integer>> alreadyTransformed;

	/**
	 * Private constructor. The generator should only be invoked using {@link #generate}.
	 *
	 * @param env The environment whose settings (chaining, checkpointing, state backend,
	 *            forced checkpointing) are copied onto the new {@code StreamGraph}
	 */
	private StreamGraphGenerator(StreamExecutionEnvironment env) {
		this.streamGraph = new StreamGraph(env);
		this.streamGraph.setChaining(env.isChainingEnabled());
		if (env.getCheckpointInterval() > 0) {
			this.streamGraph.setCheckpointingEnabled(true);
			this.streamGraph.setCheckpointingInterval(env.getCheckpointInterval());
			this.streamGraph.setCheckpointingMode(env.getCheckpointingMode());
		}
		this.streamGraph.setStateBackend(env.getStateBackend());
		if (env.isForceCheckpointing()) {
			this.streamGraph.forceCheckpoint();
		}
		this.env = env;
		this.alreadyTransformed = new HashMap<>();
	}

	/**
	 * Generates a {@code StreamGraph} by traversing the graph of {@code StreamTransformations}
	 * starting from the given transformations.
	 *
	 * @param env The {@code StreamExecutionEnvironment} that is used to set some parameters of the
	 *            job
	 * @param transformations The transformations starting from which to transform the graph
	 *
	 * @return The generated {@code StreamGraph}
	 */
	public static StreamGraph generate(StreamExecutionEnvironment env, List<StreamTransformation<?>> transformations) {
		return new StreamGraphGenerator(env).generateInternal(transformations);
	}

	/**
	 * This starts the actual transformation, beginning from the sinks.
	 */
	private StreamGraph generateInternal(List<StreamTransformation<?>> transformations) {
		for (StreamTransformation<?> transformation : transformations) {
			transform(transformation);
		}
		return streamGraph;
	}

	/**
	 * Transforms one {@code StreamTransformation}.
	 *
	 * <p>
	 * This checks whether we already transformed it and exits early in that case. If not it
	 * delegates to one of the transformation specific methods.
	 *
	 * @return The IDs of the {@code StreamGraph} nodes that represent this transformation
	 */
	private Collection<Integer> transform(StreamTransformation<?> transform) {

		if (alreadyTransformed.containsKey(transform)) {
			return alreadyTransformed.get(transform);
		}

		LOG.debug("Transforming {}", transform);

		// call at least once to trigger exceptions about MissingTypeInfo
		transform.getOutputType();

		Collection<Integer> transformedIds;
		if (transform instanceof OneInputTransformation<?, ?>) {
			transformedIds = transformOneInputTransform((OneInputTransformation<?, ?>) transform);
		} else if (transform instanceof TwoInputTransformation<?, ?, ?>) {
			transformedIds = transformTwoInputTransform((TwoInputTransformation<?, ?, ?>) transform);
		} else if (transform instanceof SourceTransformation<?>) {
			transformedIds = transformSource((SourceTransformation<?>) transform);
		} else if (transform instanceof SinkTransformation<?>) {
			transformedIds = transformSink((SinkTransformation<?>) transform);
		} else if (transform instanceof UnionTransformation<?>) {
			transformedIds = transformUnion((UnionTransformation<?>) transform);
		} else if (transform instanceof SplitTransformation<?>) {
			transformedIds = transformSplit((SplitTransformation<?>) transform);
		} else if (transform instanceof SelectTransformation<?>) {
			transformedIds = transformSelect((SelectTransformation<?>) transform);
		} else if (transform instanceof FeedbackTransformation<?>) {
			transformedIds = transformFeedback((FeedbackTransformation<?>) transform);
		} else if (transform instanceof CoFeedbackTransformation<?>) {
			transformedIds = transformCoFeedback((CoFeedbackTransformation<?>) transform);
		} else if (transform instanceof PartitionTransformation<?>) {
			transformedIds = transformPartition((PartitionTransformation<?>) transform);
		} else {
			throw new IllegalStateException("Unknown transformation: " + transform);
		}

		// need this check because the iterate transformation adds itself before
		// transforming the feedback edges
		if (!alreadyTransformed.containsKey(transform)) {
			alreadyTransformed.put(transform, transformedIds);
		}

		if (transform.getBufferTimeout() > 0) {
			streamGraph.setBufferTimeout(transform.getId(), transform.getBufferTimeout());
		}
		if (transform.getResourceStrategy() != StreamGraph.ResourceStrategy.DEFAULT) {
			streamGraph.setResourceStrategy(transform.getId(), transform.getResourceStrategy());
		}

		return transformedIds;
	}

	/**
	 * Transforms a {@code UnionTransformation}.
	 *
	 * <p>
	 * This is easy, we only have to transform the inputs and return all the IDs in a list so
	 * that downstream operations can connect to all upstream nodes.
	 */
	private <T> Collection<Integer> transformUnion(UnionTransformation<T> union) {
		List<StreamTransformation<T>> inputs = union.getInputs();
		List<Integer> resultIds = new ArrayList<>();

		for (StreamTransformation<T> input : inputs) {
			resultIds.addAll(transform(input));
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code PartitionTransformation}.
	 *
	 * <p>
	 * For this we create a virtual node in the {@code StreamGraph} that holds the partition
	 * property. @see StreamGraphGenerator
	 */
	private <T> Collection<Integer> transformPartition(PartitionTransformation<T> partition) {
		StreamTransformation<T> input = partition.getInput();
		List<Integer> resultIds = new ArrayList<>();

		Collection<Integer> transformedIds = transform(input);
		for (Integer transformedId : transformedIds) {
			int virtualId = StreamTransformation.getNewNodeId();
			streamGraph.addVirtualPartitionNode(transformedId, virtualId, partition.getPartitioner());
			resultIds.add(virtualId);
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code SplitTransformation}.
	 *
	 * <p>
	 * We add the output selector to previously transformed nodes.
	 */
	private <T> Collection<Integer> transformSplit(SplitTransformation<T> split) {
		StreamTransformation<T> input = split.getInput();
		Collection<Integer> resultIds = transform(input);

		// the recursive transform call might have transformed this already
		if (alreadyTransformed.containsKey(split)) {
			return alreadyTransformed.get(split);
		}

		for (int inputId : resultIds) {
			streamGraph.addOutputSelector(inputId, split.getOutputSelector());
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code SelectTransformation}.
	 *
	 * <p>
	 * For this we create a virtual node in the {@code StreamGraph} that holds the selected names.
	 * @see org.apache.flink.streaming.api.graph.StreamGraphGenerator
	 */
	private <T> Collection<Integer> transformSelect(SelectTransformation<T> select) {
		StreamTransformation<T> input = select.getInput();
		Collection<Integer> resultIds = transform(input);

		// the recursive transform might have already transformed this
		if (alreadyTransformed.containsKey(select)) {
			return alreadyTransformed.get(select);
		}

		List<Integer> virtualResultIds = new ArrayList<>();

		for (int inputId : resultIds) {
			int virtualId = StreamTransformation.getNewNodeId();
			streamGraph.addVirtualSelectNode(inputId, virtualId, select.getSelectedNames());
			virtualResultIds.add(virtualId);
		}
		return virtualResultIds;
	}

	/**
	 * Transforms a {@code FeedbackTransformation}.
	 *
	 * <p>
	 * This will recursively transform the input and the feedback edges. We return the concatenation
	 * of the input IDs and the feedback IDs so that downstream operations can be wired to both.
	 *
	 * <p>
	 * This is responsible for creating the IterationSource and IterationSink which
	 * are used to feed back the elements.
	 */
	private <T> Collection<Integer> transformFeedback(FeedbackTransformation<T> iterate) {

		if (iterate.getFeedbackEdges().isEmpty()) {
			throw new IllegalStateException("Iteration " + iterate + " does not have any feedback edges.");
		}

		StreamTransformation<T> input = iterate.getInput();
		List<Integer> resultIds = new ArrayList<>();

		// first transform the input stream(s) and store the result IDs
		resultIds.addAll(transform(input));

		// the recursive transform might have already transformed this
		if (alreadyTransformed.containsKey(iterate)) {
			return alreadyTransformed.get(iterate);
		}

		// create the fake iteration source/sink pair
		Tuple2<StreamNode, StreamNode> itSourceAndSink = streamGraph.createIterationSourceAndSink(
				iterate.getId(),
				getNewIterationNodeId(),
				getNewIterationNodeId(),
				iterate.getWaitTime(),
				iterate.getParallelism());

		StreamNode itSource = itSourceAndSink.f0;
		StreamNode itSink = itSourceAndSink.f1;

		// We set the proper serializers for the sink/source
		streamGraph.setSerializers(itSource.getId(), null, null, iterate.getOutputType().createSerializer(env.getConfig()));
		streamGraph.setSerializers(itSink.getId(), iterate.getOutputType().createSerializer(env.getConfig()), null, null);

		// also add the feedback source ID to the result IDs, so that downstream operators will
		// add both as input
		resultIds.add(itSource.getId());

		// add the iterate to the already-seen-set with the result IDs, so that we can transform
		// the feedback edges and let them stop when encountering the iterate node
		alreadyTransformed.put(iterate, resultIds);

		for (StreamTransformation<T> feedbackEdge : iterate.getFeedbackEdges()) {
			Collection<Integer> feedbackIds = transform(feedbackEdge);
			for (Integer feedbackId : feedbackIds) {
				streamGraph.addEdge(feedbackId,
						itSink.getId(),
						0
				);
			}
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code CoFeedbackTransformation}.
	 *
	 * <p>
	 * This will only transform feedback edges, the result of this transform will be wired
	 * to the second input of a Co-Transform. The original input is wired directly to the first
	 * input of the downstream Co-Transform.
	 *
	 * <p>
	 * This is responsible for creating the IterationSource and IterationSink which
	 * are used to feed back the elements.
	 */
	private <F> Collection<Integer> transformCoFeedback(CoFeedbackTransformation<F> coIterate) {

		// For Co-Iteration we don't need to transform the input and wire the input to the
		// head operator by returning the input IDs, the input is directly wired to the left
		// input of the co-operation. This transform only needs to return the ids of the feedback
		// edges, since they need to be wired to the second input of the co-operation.

		// create the fake iteration source/sink pair
		Tuple2<StreamNode, StreamNode> itSourceAndSink = streamGraph.createIterationSourceAndSink(
				coIterate.getId(),
				getNewIterationNodeId(),
				getNewIterationNodeId(),
				coIterate.getWaitTime(),
				coIterate.getParallelism());

		StreamNode itSource = itSourceAndSink.f0;
		StreamNode itSink = itSourceAndSink.f1;

		// We set the proper serializers for the sink/source
		streamGraph.setSerializers(itSource.getId(), null, null, coIterate.getOutputType().createSerializer(env.getConfig()));
		streamGraph.setSerializers(itSink.getId(), coIterate.getOutputType().createSerializer(env.getConfig()), null, null);

		Collection<Integer> resultIds = Collections.singleton(itSource.getId());

		// add the iterate to the already-seen-set with the result IDs, so that we can transform
		// the feedback edges and let them stop when encountering the iterate node
		alreadyTransformed.put(coIterate, resultIds);

		for (StreamTransformation<F> feedbackEdge : coIterate.getFeedbackEdges()) {
			Collection<Integer> feedbackIds = transform(feedbackEdge);
			for (Integer feedbackId : feedbackIds) {
				streamGraph.addEdge(feedbackId,
						itSink.getId(),
						0
				);
			}
		}

		return resultIds;
	}

	/**
	 * Transforms a {@code SourceTransformation}.
	 */
	private <T> Collection<Integer> transformSource(SourceTransformation<T> source) {
		streamGraph.addSource(source.getId(),
				source.getOperator(),
				null,
				source.getOutputType(),
				"Source: " + source.getName());
		if (source.getOperator().getUserFunction() instanceof FileSourceFunction) {
			FileSourceFunction<T> fs = (FileSourceFunction<T>) source.getOperator().getUserFunction();
			streamGraph.setInputFormat(source.getId(), fs.getFormat());
		}
		streamGraph.setParallelism(source.getId(), source.getParallelism());
		return Collections.singleton(source.getId());
	}

	/**
	 * Transforms a {@code SinkTransformation}.
	 */
	private <T> Collection<Integer> transformSink(SinkTransformation<T> sink) {

		Collection<Integer> inputIds = transform(sink.getInput());

		streamGraph.addSink(sink.getId(),
				sink.getOperator(),
				sink.getInput().getOutputType(),
				null,
				"Sink: " + sink.getName());

		streamGraph.setParallelism(sink.getId(), sink.getParallelism());

		for (Integer inputId : inputIds) {
			streamGraph.addEdge(inputId,
					sink.getId(),
					0
			);
		}

		if (sink.getStateKeySelector() != null) {
			TypeSerializer<?> keySerializer = sink.getStateKeyType().createSerializer(env.getConfig());
			streamGraph.setKey(sink.getId(), sink.getStateKeySelector(), keySerializer);
		}

		// sinks produce no output, so no downstream operator can connect to them
		return Collections.emptyList();
	}

	/**
	 * Transforms a {@code OneInputTransformation}.
	 *
	 * <p>
	 * This recursively transforms the inputs, creates a new {@code StreamNode} in the graph and
	 * wires the inputs to this new node.
	 */
	private <IN, OUT> Collection<Integer> transformOneInputTransform(OneInputTransformation<IN, OUT> transform) {

		Collection<Integer> inputIds = transform(transform.getInput());

		// the recursive call might have already transformed this
		if (alreadyTransformed.containsKey(transform)) {
			return alreadyTransformed.get(transform);
		}

		streamGraph.addOperator(transform.getId(),
				transform.getOperator(),
				transform.getInputType(),
				transform.getOutputType(),
				transform.getName());

		if (transform.getStateKeySelector() != null) {
			TypeSerializer<?> keySerializer = transform.getStateKeyType().createSerializer(env.getConfig());
			streamGraph.setKey(transform.getId(), transform.getStateKeySelector(), keySerializer);
		}

		streamGraph.setParallelism(transform.getId(), transform.getParallelism());

		for (Integer inputId : inputIds) {
			streamGraph.addEdge(inputId, transform.getId(), 0);
		}

		return Collections.singleton(transform.getId());
	}

	/**
	 * Transforms a {@code TwoInputTransformation}.
	 *
	 * <p>
	 * This recursively transforms the inputs, creates a new {@code StreamNode} in the graph and
	 * wires the inputs to this new node.
	 */
	private <IN1, IN2, OUT> Collection<Integer> transformTwoInputTransform(TwoInputTransformation<IN1, IN2, OUT> transform) {

		Collection<Integer> inputIds1 = transform(transform.getInput1());
		Collection<Integer> inputIds2 = transform(transform.getInput2());

		// the recursive call might have already transformed this
		if (alreadyTransformed.containsKey(transform)) {
			return alreadyTransformed.get(transform);
		}

		streamGraph.addCoOperator(
				transform.getId(),
				transform.getOperator(),
				transform.getInputType1(),
				transform.getInputType2(),
				transform.getOutputType(),
				transform.getName());

		streamGraph.setParallelism(transform.getId(), transform.getParallelism());

		// type numbers 1 and 2 tell the co-operator which logical input an edge feeds
		for (Integer inputId : inputIds1) {
			streamGraph.addEdge(inputId,
					transform.getId(),
					1
			);
		}

		for (Integer inputId : inputIds2) {
			streamGraph.addEdge(inputId,
					transform.getId(),
					2
			);
		}

		return Collections.singleton(transform.getId());
	}
}