blob: d3b985c6574e03b61b035521d562db0266379357 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.iteration;
import org.apache.flink.annotation.Experimental;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
import org.apache.flink.iteration.operator.HeadOperator;
import org.apache.flink.iteration.operator.HeadOperatorFactory;
import org.apache.flink.iteration.operator.InputOperator;
import org.apache.flink.iteration.operator.OperatorWrapper;
import org.apache.flink.iteration.operator.OutputOperator;
import org.apache.flink.iteration.operator.ReplayOperator;
import org.apache.flink.iteration.operator.TailOperator;
import org.apache.flink.iteration.operator.allround.AllRoundOperatorWrapper;
import org.apache.flink.iteration.operator.perround.PerRoundOperatorWrapper;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.iteration.utils.DataStreamUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.apache.flink.util.Preconditions.checkState;
/**
* A helper class to create iterations. To construct an iteration, Users are required to provide
*
* <ul>
* <li>initVariableStreams: the initial values of the variable data streams which would be updated
* in each round.
* <li>dataStreams: the other data streams used inside the iteration, but would not be updated.
* <li>iterationBody: specifies the subgraph to update the variable streams and the outputs.
* </ul>
*
* <p>The iteration body will be invoked with two parameters: The first parameter is a list of input
* variable streams, which are created as the union of the initial variable streams and the
* corresponding feedback variable streams (returned by the iteration body); The second parameter is
* the data streams given to this method.
*
* <p>During the execution of the iteration body, each of the records involved in the iteration has
* an epoch attached, which marks the progress of the iteration. The epoch is computed as:
*
* <ul>
* <li>All records in the initial variable streams and initial data streams have epoch = 0.
* <li>For any record emitted by this operator into a non-feedback stream, the epoch of this
* emitted record = the epoch of the input record that triggers this emission. If this record
* is emitted by onEpochWatermarkIncremented(), then the epoch of this record =
* epochWatermark.
* <li>For any record emitted by this operator into a feedback variable stream, the epoch of the
* emitted record = the epoch of the input record that triggers this emission + 1.
* </ul>
*
* <p>The framework notifies the operators and UDFs that implement {@link IterationListener} at
* the end of each epoch.
*
* <p>The limitations on constructing the subgraph inside the iteration body are described in
* {@link IterationBody}.
*
* <p>Note that the iteration framework cannot deal with watermarks correctly for now. It should be
* resolved by FLINK-31373.
*
* <p>An example of the iteration is like:
*
* <pre>{@code
* DataStreamList result = Iterations.iterateUnboundedStreams(
* DataStreamList.of(first, second),
* DataStreamList.of(third),
* (variableStreams, dataStreams) -> {
* ...
* return new IterationBodyResult(
* DataStreamList.of(firstFeedback, secondFeedback),
* DataStreamList.of(output));
* });
* result.<Integer>get(0).addSink(...);
* }</pre>
*/
@Experimental
public class Iterations {
/**
 * Runs the given iteration body over possibly unbounded data streams.
 *
 * <p>The iteration does not terminate if at least one of its inputs is unbounded. Otherwise it
 * terminates after all the inputs have terminated and no more records are iterating.
 *
 * @param initVariableStreams The initial variable streams, which are merged with the feedback
 *     variable streams before being used as the 1st parameter to invoke the iteration body.
 * @param dataStreams The non-variable streams that are also referenced in the {@code body}.
 * @param body The computation logic which takes variable/data streams and returns
 *     feedback/output streams.
 * @return The list of output streams returned by the iteration body.
 */
public static DataStreamList iterateUnboundedStreams(
        DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
    // No data stream is replayed, and createIteration() rejects a termination criteria when
    // the last flag is false.
    final Set<Integer> noReplayedStreams = Collections.emptySet();
    final boolean criteriaAllowed = false;

    // The wrapper is deliberately passed as a raw type, exactly as createIteration() expects.
    return createIteration(
            initVariableStreams,
            dataStreams,
            noReplayedStreams,
            body,
            new AllRoundOperatorWrapper(),
            criteriaAllowed);
}
/**
 * Iteratively processes records of bounded data streams with the given iteration body, until no
 * more records are iterating or the termination criteria stream is empty in one round.
 *
 * @param initVariableStreams The initial variable streams, which are merged with the feedback
 *     variable streams before being used as the 1st parameter to invoke the iteration body.
 * @param dataStreams The non-variable streams that are also referenced in the {@code body},
 *     together with whether each of them needs to be replayed for each round.
 * @param config The config for the iteration, like whether to re-create the operator on each
 *     round.
 * @param body The computation logic which takes variable/data streams and returns
 *     feedback/output streams.
 * @return The list of output streams returned by the iteration body.
 */
public static DataStreamList iterateBoundedStreamsUntilTermination(
        DataStreamList initVariableStreams,
        ReplayableDataStreamList dataStreams,
        IterationConfig config,
        IterationBody body) {
    // Select the operator wrapper according to the configured operator life cycle.
    OperatorWrapper wrapper;
    if (config.getOperatorLifeCycle() == IterationConfig.OperatorLifeCycle.ALL_ROUND) {
        wrapper = new AllRoundOperatorWrapper<>();
    } else {
        wrapper = new PerRoundOperatorWrapper<>();
    }

    // The replayed streams are placed first, so their positions in the merged list are exactly
    // [0, number of replayed streams).
    List<DataStream<?>> mergedDataStreams =
            new ArrayList<>(dataStreams.getReplayedDataStreams());
    mergedDataStreams.addAll(dataStreams.getNonReplayedStreams());

    int replayedCount = dataStreams.getReplayedDataStreams().size();
    Set<Integer> replayedPositions =
            IntStream.range(0, replayedCount)
                    .mapToObj(Integer::valueOf)
                    .collect(Collectors.toSet());

    final boolean criteriaAllowed = true;
    return createIteration(
            initVariableStreams,
            new DataStreamList(mergedDataStreams),
            replayedPositions,
            body,
            wrapper,
            criteriaAllowed);
}
/**
 * Builds the graph of one iteration around the user-provided {@code body}.
 *
 * <p>The graph is constructed in the following order:
 *
 * <ol>
 *   <li>Every initial variable stream gets an input operator followed by a head operator. Every
 *       data stream gets an input operator, and the data streams whose indices are contained
 *       in {@code replayedDataStreamIndices} are additionally routed through a replayer.
 *   <li>The iteration body is invoked against draft sources created in a {@link
 *       DraftExecutionEnvironment}; the resulting draft subgraph is then copied into the actual
 *       environment.
 *   <li>A tail is appended to every feedback stream and put into the same co-location group as
 *       the head with the same index.
 *   <li>If the body returned a termination criteria stream, a dedicated head / tail pair is
 *       created for it.
 *   <li>Every output stream of the body is unioned with the (record-free) union of all the
 *       tails and followed by an output operator.
 * </ol>
 *
 * @param initVariableStreams The initial variable streams; at least one is required.
 * @param dataStreams The non-variable streams passed to the body as its 2nd parameter.
 * @param replayedDataStreamIndices The indices (into {@code dataStreams}) of the streams that
 *     get a replayer.
 * @param body The iteration body.
 * @param initialOperatorWrapper The operator wrapper handed to the draft environment.
 * @param mayHaveCriteria Whether the body is allowed to return a termination criteria stream.
 * @return The actual output streams of the iteration.
 * @throws IllegalStateException if a {@code checkState} precondition fails: there is no
 *     variable stream, the number or the parallelism of the feedback streams does not match
 *     the heads, or a termination criteria is returned although {@code mayHaveCriteria} is
 *     {@code false}.
 */
@SuppressWarnings({"unchecked", "rawtypes"})
private static DataStreamList createIteration(
        DataStreamList initVariableStreams,
        DataStreamList dataStreams,
        Set<Integer> replayedDataStreamIndices,
        IterationBody body,
        OperatorWrapper<?, IterationRecord<?>> initialOperatorWrapper,
        boolean mayHaveCriteria) {
    checkState(initVariableStreams.size() > 0, "There should be at least one variable stream");
    IterationID iterationId = new IterationID();

    List<TypeInformation<?>> initVariableTypeInfos = getTypeInfos(initVariableStreams);
    List<TypeInformation<?>> dataStreamTypeInfos = getTypeInfos(dataStreams);

    // Add heads and inputs.
    // Streams without a positive parallelism are counted with the default parallelism of
    // their environment.
    int totalInitVariableParallelism = 0;
    for (int i = 0; i < initVariableStreams.size(); ++i) {
        DataStream<?> variableStream = initVariableStreams.get(i);
        totalInitVariableParallelism +=
                variableStream.getParallelism() > 0
                        ? variableStream.getParallelism()
                        : variableStream
                                .getExecutionEnvironment()
                                .getConfig()
                                .getParallelism();
    }

    DataStreamList initVariableInputs = addInputs(initVariableStreams);
    DataStreamList headStreams =
            addHeads(
                    initVariableStreams,
                    initVariableInputs,
                    iterationId,
                    totalInitVariableParallelism,
                    false,
                    0);

    DataStreamList dataStreamInputs = addInputs(dataStreams);
    if (!replayedDataStreamIndices.isEmpty()) {
        dataStreamInputs =
                addReplayer(
                        headStreams.get(0),
                        dataStreams,
                        dataStreamInputs,
                        replayedDataStreamIndices);
    }

    // Creates the iteration body. We map the inputs of iteration body into the draft sources,
    // which serve as the start points to build the draft subgraph.
    StreamExecutionEnvironment env = initVariableStreams.get(0).getExecutionEnvironment();
    DraftExecutionEnvironment draftEnv =
            new DraftExecutionEnvironment(env, initialOperatorWrapper);
    DataStreamList draftHeadStreams =
            addDraftSources(headStreams, draftEnv, initVariableTypeInfos);
    DataStreamList draftDataStreamInputs =
            addDraftSources(dataStreamInputs, draftEnv, dataStreamTypeInfos);

    IterationBodyResult iterationBodyResult =
            body.process(draftHeadStreams, draftDataStreamInputs);
    ensuresTransformationAdded(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
    ensuresTransformationAdded(iterationBodyResult.getOutputStreams(), draftEnv);

    // The termination criteria is returned by the body just like the feedback and the output
    // streams and is resolved through draftEnv.getActualStream() in addCriteriaStream(), so its
    // transformation has to be known to the draft environment before the copy as well.
    // NOTE(review): this presumably matters for criteria streams whose transformation is not
    // registered on creation (e.g. a union); confirm against DraftExecutionEnvironment.
    DataStream<?> draftTerminationCriteria = iterationBodyResult.getTerminationCriteria();
    if (draftTerminationCriteria != null) {
        draftEnv.addOperatorIfNotExists(draftTerminationCriteria.getTransformation());
    }

    draftEnv.copyToActualEnvironment();

    // Adds tails and co-locate them with the heads.
    DataStreamList feedbackStreams =
            getActualDataStreams(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
    checkState(
            feedbackStreams.size() == initVariableStreams.size(),
            "The number of feedback streams "
                    + feedbackStreams.size()
                    + " does not match the initialized one "
                    + initVariableStreams.size());
    for (int i = 0; i < feedbackStreams.size(); ++i) {
        checkState(
                feedbackStreams.get(i).getParallelism() == headStreams.get(i).getParallelism(),
                String.format(
                        "The feedback stream %d have different parallelism %d with the initial stream, which is %d",
                        i,
                        feedbackStreams.get(i).getParallelism(),
                        headStreams.get(i).getParallelism()));
    }

    DataStreamList tails = addTails(feedbackStreams, iterationId, 0);
    for (int i = 0; i < headStreams.size(); ++i) {
        String coLocationGroupKey = "co-" + iterationId.toHexString() + "-" + i;
        headStreams.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
        tails.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
    }

    List<DataStream<?>> tailsAndCriteriaTails = new ArrayList<>(tails.getDataStreams());
    checkState(
            mayHaveCriteria || draftTerminationCriteria == null,
            "The current iteration type does not support the termination criteria.");
    if (draftTerminationCriteria != null) {
        DataStreamList criteriaTails =
                addCriteriaStream(
                        draftTerminationCriteria,
                        iterationId,
                        env,
                        draftEnv,
                        initVariableStreams,
                        headStreams,
                        totalInitVariableParallelism);
        tailsAndCriteriaTails.addAll(criteriaTails.getDataStreams());
    }

    // The union of all the tails carries no records (see unionAllTails()); it is only unioned
    // into the inputs of the output operators.
    DataStream<Integer> tailsUnion =
            unionAllTails(env, new DataStreamList(tailsAndCriteriaTails));
    return addOutputs(
            getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv), tailsUnion);
}
/**
 * Routes the data stream inputs that need to be replayed through a {@link ReplayOperator}.
 *
 * <p>Inputs whose index is not contained in {@code replayedDataStreamIndices} are passed through
 * unchanged. Each of the others is connected with the broadcast {@link
 * HeadOperator#ALIGN_NOTIFY_OUTPUT_TAG} side output of the first head and transformed by a
 * replay operator that keeps the type and the parallelism of the input.
 *
 * @param firstHeadStream The head stream whose side output is connected to the replayers.
 * @param originalDataStreams The original data streams, only used to name the replayers.
 * @param dataStreamInputs The data streams after the input operators.
 * @param replayedDataStreamIndices The indices of the inputs to replay.
 * @return A list with the same size and order as {@code dataStreamInputs}.
 */
private static DataStreamList addReplayer(
        DataStream<?> firstHeadStream,
        DataStreamList originalDataStreams,
        DataStreamList dataStreamInputs,
        Set<Integer> replayedDataStreamIndices) {
    final int inputCount = dataStreamInputs.size();
    List<DataStream<?>> replayedOrOriginal = new ArrayList<>(inputCount);

    for (int index = 0; index < inputCount; ++index) {
        if (replayedDataStreamIndices.contains(index)) {
            String replayerName =
                    "Replayer-" + originalDataStreams.get(index).getTransformation().getName();

            // The HeadOperator broadcasts the globally aligned events, so the replayer only
            // has to listen to the side output of a single head.
            replayedOrOriginal.add(
                    dataStreamInputs
                            .get(index)
                            .connect(
                                    ((SingleOutputStreamOperator<IterationRecord<?>>)
                                                    firstHeadStream)
                                            .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG)
                                            .broadcast())
                            .transform(
                                    replayerName,
                                    dataStreamInputs.get(index).getType(),
                                    (TwoInputStreamOperator) new ReplayOperator<>())
                            .setParallelism(dataStreamInputs.get(index).getParallelism()));
        } else {
            replayedOrOriginal.add(dataStreamInputs.get(index));
        }
    }

    return new DataStreamList(replayedOrOriginal);
}
/**
 * Creates the head / tail pair for the termination criteria stream returned by the iteration
 * body.
 *
 * <p>An empty source with the inner type, name and parallelism of the criteria stream is fed
 * through an input operator and a criteria head. The head is merged with the actual criteria
 * stream, a tail is appended to the merged stream, and head and tail are put into the same
 * co-location group. Finally all the heads are told the parallelism of the criteria stream.
 *
 * @param draftCriteriaStream The criteria stream as returned by the body (draft graph).
 * @param iterationId The id of the iteration.
 * @param env The actual execution environment.
 * @param draftEnv The draft environment used to resolve the actual criteria stream.
 * @param initVariableStreams The initial variable streams; their count is used as the index of
 *     the criteria head and tail.
 * @param headStreams The heads of the variable streams.
 * @param totalInitVariableParallelism The summed parallelism of the initial variable streams.
 * @return The list holding the single criteria tail.
 */
private static DataStreamList addCriteriaStream(
        DataStream<?> draftCriteriaStream,
        IterationID iterationId,
        StreamExecutionEnvironment env,
        DraftExecutionEnvironment draftEnv,
        DataStreamList initVariableStreams,
        DataStreamList headStreams,
        int totalInitVariableParallelism) {
    // Resolves the criteria stream in the actual graph; it must carry IterationRecords.
    DataStream<?> actualCriteria = draftEnv.getActualStream(draftCriteriaStream.getId());
    checkState(
            IterationRecordTypeInfo.class.equals(actualCriteria.getType().getClass()),
            "The termination criteria should always return IterationRecord.");
    TypeInformation<?> criteriaValueType =
            ((IterationRecordTypeInfo<?>) actualCriteria.getType()).getInnerTypeInfo();

    // The criteria head and tail are indexed right after those of the variable streams.
    final int criteriaIndex = initVariableStreams.size();

    DataStream<?> placeholderSource =
            env.addSource(new DraftExecutionEnvironment.EmptySource())
                    .returns(criteriaValueType)
                    .name(actualCriteria.getTransformation().getName())
                    .setParallelism(actualCriteria.getParallelism());
    DataStreamList placeholderSources = DataStreamList.of(placeholderSource);
    DataStreamList criteriaHeads =
            addHeads(
                    placeholderSources,
                    addInputs(placeholderSources),
                    iterationId,
                    totalInitVariableParallelism,
                    true,
                    criteriaIndex);
    DataStream<?> criteriaHead = criteriaHeads.get(0);

    // Without an edge from the criteria head to the criteria tail, the tail might directly
    // receive the MAX_EPOCH_WATERMARK without the synchronization of the head. Hence the head
    // is merged with the actual criteria stream before the tail is added.
    DataStream<?> headMergedCriteria =
            mergeCriteriaHeadAndCriteriaStream(
                    env, criteriaHead, actualCriteria, criteriaValueType);
    DataStreamList criteriaTails =
            addTails(DataStreamList.of(headMergedCriteria), iterationId, criteriaIndex);

    String coLocationGroupKey = "co-" + iterationId.toHexString() + "-cri";
    criteriaHead.getTransformation().setCoLocationGroupKey(coLocationGroupKey);
    criteriaTails.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);

    // Notifies all the head operators to count the criteria streams.
    int criteriaParallelism = actualCriteria.getParallelism();
    setCriteriaParallelism(headStreams, criteriaParallelism);
    setCriteriaParallelism(criteriaHeads, criteriaParallelism);

    return criteriaTails;
}
/**
 * Connects the criteria head with the actual criteria stream through a {@link
 * CriteriaMergeProcessor} and returns the merged stream of the actual environment.
 *
 * <p>The merge is built in its own {@link DraftExecutionEnvironment} (with an {@link
 * AllRoundOperatorWrapper}) and then copied into {@code env}. The merged stream uses the
 * parallelism of the criteria stream, or the default parallelism of {@code env} if the former
 * is not positive.
 *
 * @param env The actual execution environment.
 * @param head The criteria head stream.
 * @param criteriaStream The actual criteria stream.
 * @param criteriaStreamType The type declared for both draft sources and for the merged stream.
 * @return The merged stream in the actual environment.
 */
@SuppressWarnings({"unchecked", "rawtypes"})
private static DataStream<?> mergeCriteriaHeadAndCriteriaStream(
        StreamExecutionEnvironment env,
        DataStream<?> head,
        DataStream<?> criteriaStream,
        TypeInformation<?> criteriaStreamType) {
    DraftExecutionEnvironment mergeDraftEnv =
            new DraftExecutionEnvironment(env, new AllRoundOperatorWrapper<>());
    DataStream headDraft = mergeDraftEnv.addDraftSource(head, criteriaStreamType);
    DataStream criteriaDraft = mergeDraftEnv.addDraftSource(criteriaStream, criteriaStreamType);

    int mergeParallelism = criteriaStream.getParallelism();
    if (mergeParallelism <= 0) {
        mergeParallelism = env.getConfig().getParallelism();
    }

    DataStream mergedDraft =
            headDraft
                    .connect(criteriaDraft)
                    .process(new CriteriaMergeProcessor())
                    .returns(criteriaStreamType)
                    .setParallelism(mergeParallelism)
                    .name("criteria-merge");

    mergeDraftEnv.copyToActualEnvironment();
    return mergeDraftEnv.getActualStream(mergedDraft.getId());
}
/**
 * Unions all the tails (including the criteria tails) into one stream that never carries a
 * record.
 *
 * <p>Each tail is followed by a filter that rejects every record and is declared to produce
 * {@code Integer}s, so that tails of different types can be unioned. The filters use the
 * parallelism of their tail, or the default parallelism of {@code env} if the former is not
 * positive.
 *
 * @param env The execution environment providing the default parallelism.
 * @param tailsAndCriteriaTails The tails to union; must not be empty.
 * @return The union of the filtered tails.
 */
@SuppressWarnings({"unchecked", "rawtypes"})
private static DataStream<Integer> unionAllTails(
        StreamExecutionEnvironment env, DataStreamList tailsAndCriteriaTails) {
    // All the filters are created first and only unioned afterwards.
    List<DataStream> recordFreeTails =
            Iterations.<DataStream>map(
                    tailsAndCriteriaTails,
                    tail -> {
                        int filterParallelism = tail.getParallelism();
                        if (filterParallelism <= 0) {
                            filterParallelism = env.getConfig().getParallelism();
                        }
                        return tail.filter(r -> false)
                                .name("filter-tail")
                                .returns((TypeInformation) Types.INT)
                                .setParallelism(filterParallelism);
                    });

    return recordFreeTails.stream().reduce((merged, next) -> merged.union(next)).get();
}
/**
 * Collects the type information of every stream in the list, in list order.
 *
 * @param dataStreams The streams to inspect.
 * @return The result of {@link DataStream#getType()} for each stream.
 */
private static List<TypeInformation<?>> getTypeInfos(DataStreamList dataStreams) {
    final int streamCount = dataStreams.size();
    List<TypeInformation<?>> typeInfos = new ArrayList<>(streamCount);
    for (int position = 0; position < streamCount; ++position) {
        DataStream<?> stream = dataStreams.get(position);
        typeInfos.add(stream.getType());
    }
    return typeInfos;
}
/**
 * Appends an {@link InputOperator} to every given stream.
 *
 * <p>Each resulting stream is named {@code "input-" + <name of the original transformation>},
 * is typed with an {@link IterationRecordTypeInfo} built from the original type, and keeps the
 * parallelism of the original stream.
 *
 * @param dataStreams The streams to wrap.
 * @return The wrapped streams, in the same order.
 */
private static DataStreamList addInputs(DataStreamList dataStreams) {
    final int streamCount = dataStreams.size();
    List<DataStream<?>> wrappedInputs = new ArrayList<>(streamCount);
    for (int position = 0; position < streamCount; ++position) {
        DataStream<?> source = dataStreams.get(position);
        wrappedInputs.add(
                source.transform(
                                "input-" + source.getTransformation().getName(),
                                new IterationRecordTypeInfo<>(source.getType()),
                                new InputOperator())
                        .setParallelism(source.getParallelism()));
    }
    return new DataStreamList(wrappedInputs);
}
/**
 * Appends a head operator to every input stream.
 *
 * <p>The head at position {@code i} is named after the {@code i}-th variable stream, keeps the
 * type and the parallelism of its input, is created by a {@link HeadOperatorFactory} with the
 * index {@code startHeaderIndex + i}, and gets a managed memory weight of 100.
 *
 * @param variableStreams The original streams, only used to name the heads.
 * @param inputStreams The streams (after the input operators) to append the heads to.
 * @param iterationId The id of the iteration.
 * @param totalInitVariableParallelism The summed parallelism of the initial variable streams.
 * @param isCriteriaStream Whether the heads are created for the criteria stream.
 * @param startHeaderIndex The index assigned to the first head.
 * @return The head streams, in the same order as {@code inputStreams}.
 */
private static DataStreamList addHeads(
        DataStreamList variableStreams,
        DataStreamList inputStreams,
        IterationID iterationId,
        int totalInitVariableParallelism,
        boolean isCriteriaStream,
        int startHeaderIndex) {
    final int headCount = inputStreams.size();
    List<DataStream<?>> heads = new ArrayList<>(headCount);

    for (int offset = 0; offset < headCount; ++offset) {
        DataStream<?> input = inputStreams.get(offset);
        String headName =
                "head-" + variableStreams.get(offset).getTransformation().getName();

        DataStream head =
                ((SingleOutputStreamOperator<IterationRecord<?>>) input)
                        .transform(
                                headName,
                                (IterationRecordTypeInfo) input.getType(),
                                new HeadOperatorFactory(
                                        iterationId,
                                        startHeaderIndex + offset,
                                        isCriteriaStream,
                                        totalInitVariableParallelism))
                        .setParallelism(input.getParallelism());
        DataStreamUtils.setManagedMemoryWeight(head, 100);
        heads.add(head);
    }

    return new DataStreamList(heads);
}
/**
 * Appends a {@link TailOperator} to every given stream.
 *
 * <p>The tail at position {@code i} uses the index {@code startIndex + i} and keeps the
 * parallelism of its input. A stream is rejected if its transformation is not a {@link
 * PhysicalTransformation} and has more than one input.
 *
 * @param dataStreams The streams to append the tails to.
 * @param iterationId The id of the iteration.
 * @param startIndex The index assigned to the first tail.
 * @return The tail streams, in the same order as {@code dataStreams}.
 * @throws UnsupportedOperationException if a stream would give the tail more than one input.
 */
private static DataStreamList addTails(
        DataStreamList dataStreams, IterationID iterationId, int startIndex) {
    final int tailCount = dataStreams.size();
    List<DataStream<?>> tails = new ArrayList<>(tailCount);

    for (int offset = 0; offset < tailCount; ++offset) {
        DataStream<?> feedback = dataStreams.get(offset);
        Transformation<?> upstream = feedback.getTransformation();

        boolean singleInput =
                upstream instanceof PhysicalTransformation || upstream.getInputs().size() <= 1;
        if (!singleInput) {
            // TODO: Support epoch watermark alignment for TailOperator.
            throw new UnsupportedOperationException(
                    "Tail operator should have only one input. Please check whether operator \""
                            + upstream.getName()
                            + "\" contains multiple inputs.");
        }

        tails.add(
                ((DataStream<IterationRecord<?>>) feedback)
                        .transform(
                                "tail-" + feedback.getTransformation().getName(),
                                new IterationRecordTypeInfo(feedback.getType()),
                                new TailOperator(iterationId, startIndex + offset))
                        .setParallelism(feedback.getParallelism()));
    }

    return new DataStreamList(tails);
}
/**
 * Appends an {@link OutputOperator} to every output stream of the iteration body.
 *
 * <p>Before the output operator, each stream is unioned with an identity map (parallelism 1)
 * over {@code tailsUnion} that is declared with the type of the stream. The output operator
 * produces the inner type of the stream's {@link IterationRecordTypeInfo} and keeps the
 * parallelism of the stream.
 *
 * @param dataStreams The actual output streams of the body; their type is cast to {@link
 *     IterationRecordTypeInfo}.
 * @param tailsUnion The union of all the tails.
 * @return The final output streams, in the same order as {@code dataStreams}.
 */
@SuppressWarnings({"unchecked", "rawtypes"})
private static DataStreamList addOutputs(DataStreamList dataStreams, DataStream tailsUnion) {
    final int outputCount = dataStreams.size();
    List<DataStream<?>> outputs = new ArrayList<>(outputCount);

    for (int position = 0; position < outputCount; ++position) {
        DataStream<?> bodyOutput = dataStreams.get(position);
        IterationRecordTypeInfo<?> recordType =
                (IterationRecordTypeInfo<?>) bodyOutput.getType();
        String upstreamName = bodyOutput.getTransformation().getName();

        outputs.add(
                bodyOutput
                        .union(
                                tailsUnion
                                        .map(x -> x)
                                        .name("tail-map-" + upstreamName)
                                        .returns(recordType)
                                        .setParallelism(1))
                        .transform(
                                "output-" + upstreamName,
                                recordType.getInnerTypeInfo(),
                                new OutputOperator())
                        .setParallelism(bodyOutput.getParallelism()));
    }

    return new DataStreamList(outputs);
}
/**
 * Creates one draft source in {@code draftEnv} for every given stream.
 *
 * @param dataStreams The actual streams to create draft sources for.
 * @param draftEnv The draft environment the sources are added to.
 * @param typeInfos The type passed to the draft source of the stream at the same position.
 * @return The draft sources, in the same order as {@code dataStreams}.
 */
private static DataStreamList addDraftSources(
        DataStreamList dataStreams,
        DraftExecutionEnvironment draftEnv,
        List<TypeInformation<?>> typeInfos) {
    final int streamCount = dataStreams.size();
    List<DataStream<?>> draftSources = new ArrayList<>(streamCount);
    for (int position = 0; position < streamCount; ++position) {
        DataStream<?> actualStream = dataStreams.get(position);
        draftSources.add(draftEnv.addDraftSource(actualStream, typeInfos.get(position)));
    }
    return new DataStreamList(draftSources);
}
/**
 * Hands the transformation of every given stream to {@link
 * DraftExecutionEnvironment#addOperatorIfNotExists}, in list order.
 *
 * @param dataStreams The draft streams returned by the iteration body.
 * @param draftEnv The draft environment the transformations are added to.
 */
private static void ensuresTransformationAdded(
        DataStreamList dataStreams, DraftExecutionEnvironment draftEnv) {
    for (int position = 0; position < dataStreams.size(); ++position) {
        DataStream<?> draftStream = dataStreams.get(position);
        draftEnv.addOperatorIfNotExists(draftStream.getTransformation());
    }
}
/**
 * Sets the criteria stream parallelism on the {@link HeadOperatorFactory} of every given head.
 *
 * <p>Each stream is expected to be produced by a {@link OneInputTransformation} whose operator
 * factory is a {@link HeadOperatorFactory}; both are enforced by casts.
 *
 * @param headStreams The head streams to update.
 * @param criteriaParallelism The parallelism of the termination criteria stream.
 */
private static void setCriteriaParallelism(
        DataStreamList headStreams, int criteriaParallelism) {
    for (int position = 0; position < headStreams.size(); ++position) {
        DataStream<?> head = headStreams.get(position);
        OneInputTransformation headTransformation =
                (OneInputTransformation) head.getTransformation();
        HeadOperatorFactory headFactory =
                (HeadOperatorFactory) headTransformation.getOperatorFactory();
        headFactory.setCriteriaStreamParallelism(criteriaParallelism);
    }
}
/**
 * Looks up, by stream id, the actual stream that {@code draftEnv} holds for every given draft
 * stream.
 *
 * @param draftStreams The streams of the draft graph.
 * @param draftEnv The draft environment to resolve the streams in.
 * @return The actual streams, in the same order as {@code draftStreams}.
 */
private static DataStreamList getActualDataStreams(
        DataStreamList draftStreams, DraftExecutionEnvironment draftEnv) {
    final int streamCount = draftStreams.size();
    List<DataStream<?>> actualStreams = new ArrayList<>(streamCount);
    for (int position = 0; position < streamCount; ++position) {
        DataStream<?> draftStream = draftStreams.get(position);
        actualStreams.add(draftEnv.getActualStream(draftStream.getId()));
    }
    return new DataStreamList(actualStreams);
}
/**
 * Applies {@code mapper} to every stream of the list, in list order.
 *
 * @param dataStreams The streams to map.
 * @param mapper The function applied to each stream.
 * @param <R> The type of the mapped values.
 * @return The mapped values, in the same order as the streams.
 */
private static <R> List<R> map(DataStreamList dataStreams, Function<DataStream<?>, R> mapper) {
    // Delegates to the index-aware overload and simply drops the index.
    BiFunction<Integer, DataStream<?>, R> indexIgnoringMapper =
            (ignoredIndex, stream) -> mapper.apply(stream);
    return map(dataStreams, indexIgnoringMapper);
}
/**
 * Applies {@code mapper} to every stream of the list together with its index, in list order.
 *
 * @param dataStreams The streams to map.
 * @param mapper The function applied to each (index, stream) pair.
 * @param <R> The type of the mapped values.
 * @return A new list with the mapped values, in the same order as the streams.
 */
private static <R> List<R> map(
        DataStreamList dataStreams, BiFunction<Integer, DataStream<?>, R> mapper) {
    final int streamCount = dataStreams.size();
    final List<R> mapped = new ArrayList<>(streamCount);
    int position = 0;
    while (position < streamCount) {
        mapped.add(mapper.apply(position, dataStreams.get(position)));
        position++;
    }
    return mapped;
}
/**
 * Merges two connected inputs by dropping every record of the first input and forwarding every
 * record of the second input unchanged.
 *
 * <p>In {@code mergeCriteriaHeadAndCriteriaStream} the first input is the criteria head and the
 * second one is the actual termination criteria stream.
 */
private static class CriteriaMergeProcessor extends CoProcessFunction<Object, Object, Object> {

    /** Drops the record: nothing from the first (head) input is forwarded. */
    @Override
    public void processElement1(Object headRecord, Context context, Collector<Object> collector)
            throws Exception {
        // Intentionally empty.
    }

    /** Forwards the record of the second (criteria) input as is. */
    @Override
    public void processElement2(
            Object criteriaRecord, Context context, Collector<Object> collector)
            throws Exception {
        collector.collect(criteriaRecord);
    }
}
}