blob: 4e8088f85f8a703224b14dcee1218b1abc5589d3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction.graph;
import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
import org.apache.beam.model.pipeline.v1.RunnerApi.ComponentsOrBuilder;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.runners.core.construction.ModelCoders;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.graph.ProtoOverrides.TransformReplacement;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
/**
* A set of transform replacements for expanding a splittable ParDo into various sub components.
*
* <p>Further details about the expansion can be found at <a
* href="https://github.com/apache/beam/blob/cb15994d5228f729dda922419b08520c8be8804e/model/pipeline/src/main/proto/beam_runner_api.proto#L279"
* />
*/
public class SplittableParDoExpander {
/**
* Returns a transform replacement which expands a splittable ParDo from:
*
* <pre>{@code
* sideInputA ---------\
* sideInputB ---------V
* mainInput ---> SplittableParDo --> outputA
* \-> outputB
* }</pre>
*
* into:
*
* <pre>{@code
* sideInputA ---------\---------------------\--------------------------\
* sideInputB ---------V---------------------V--------------------------V
* mainInput ---> PairWithRestricton --> SplitAndSize --> ProcessSizedElementsAndRestriction --> outputA
* \-> outputB
* }</pre>
*
* <p>Specifically this transform ensures that initial splitting is performed and that the sizing
* information is available to the runner if it chooses to inspect it.
*/
public static TransformReplacement createSizedReplacement() {
return SizedReplacement.builder().setDrain(false).build();
}
/**
* Returns a transform replacement in drain mode which expands a splittable ParDo from:
*
* <pre>{@code
* sideInputA ---------\
* sideInputB ---------V
* mainInput ---> SplittableParDo --> outputA
* \-> outputB
* }</pre>
*
* into:
*
* <pre>{@code
* sideInputA ---------\---------------------\----------------------\--------------------------\
* sideInputB ---------V---------------------V----------------------V--------------------------V
* mainInput ---> PairWithRestriction --> SplitAndSize --> TruncateAndSize --> ProcessSizedElementsAndRestriction --> outputA
* \-> outputB
* }</pre>
*
* .
*/
public static TransformReplacement createTruncateReplacement() {
return SizedReplacement.builder().setDrain(true).build();
}
/** See {@link #createSizedReplacement()} for details. */
@AutoValue
abstract static class SizedReplacement implements TransformReplacement {
static Builder builder() {
return new AutoValue_SplittableParDoExpander_SizedReplacement.Builder();
}
abstract boolean isDrain();
@AutoValue.Builder
abstract static class Builder {
abstract Builder setDrain(boolean isDrain);
abstract SizedReplacement build();
}
@Override
public MessageWithComponents getReplacement(
String transformId, ComponentsOrBuilder existingComponents) {
try {
MessageWithComponents.Builder rval = MessageWithComponents.newBuilder();
PTransform splittableParDo = existingComponents.getTransformsOrThrow(transformId);
ParDoPayload payload = ParDoPayload.parseFrom(splittableParDo.getSpec().getPayload());
// Only perform the expansion if this is a splittable DoFn.
if (payload.getRestrictionCoderId() == null || payload.getRestrictionCoderId().isEmpty()) {
return null;
}
String mainInputName = ParDoTranslation.getMainInputName(splittableParDo);
String mainInputPCollectionId = splittableParDo.getInputsOrThrow(mainInputName);
PCollection mainInputPCollection =
existingComponents.getPcollectionsOrThrow(mainInputPCollectionId);
Map<String, String> sideInputs =
Maps.filterKeys(
splittableParDo.getInputsMap(), input -> payload.containsSideInputs(input));
String pairWithRestrictionOutCoderId =
generateUniqueId(
mainInputPCollection.getCoderId() + "/PairWithRestriction",
existingComponents::containsCoders);
rval.getComponentsBuilder()
.putCoders(
pairWithRestrictionOutCoderId,
ModelCoders.kvCoder(
mainInputPCollection.getCoderId(), payload.getRestrictionCoderId()));
String pairWithRestrictionOutId =
generateUniqueId(
mainInputPCollectionId + "/PairWithRestriction",
existingComponents::containsPcollections);
rval.getComponentsBuilder()
.putPcollections(
pairWithRestrictionOutId,
PCollection.newBuilder()
.setCoderId(pairWithRestrictionOutCoderId)
.setIsBounded(mainInputPCollection.getIsBounded())
.setWindowingStrategyId(mainInputPCollection.getWindowingStrategyId())
.setUniqueName(
generateUniquePCollectonName(
mainInputPCollection.getUniqueName() + "/PairWithRestriction",
existingComponents))
.build());
String splitAndSizeOutCoderId =
generateUniqueId(
mainInputPCollection.getCoderId() + "/SplitAndSize",
existingComponents::containsCoders);
rval.getComponentsBuilder()
.putCoders(
splitAndSizeOutCoderId,
ModelCoders.kvCoder(
pairWithRestrictionOutCoderId, getOrAddDoubleCoder(existingComponents, rval)));
String splitAndSizeOutId =
generateUniqueId(
mainInputPCollectionId + "/SplitAndSize", existingComponents::containsPcollections);
rval.getComponentsBuilder()
.putPcollections(
splitAndSizeOutId,
PCollection.newBuilder()
.setCoderId(splitAndSizeOutCoderId)
.setIsBounded(mainInputPCollection.getIsBounded())
.setWindowingStrategyId(mainInputPCollection.getWindowingStrategyId())
.setUniqueName(
generateUniquePCollectonName(
mainInputPCollection.getUniqueName() + "/SplitAndSize",
existingComponents))
.build());
String pairWithRestrictionId =
generateUniqueId(
transformId + "/PairWithRestriction", existingComponents::containsTransforms);
{
PTransform.Builder pairWithRestriction = PTransform.newBuilder();
pairWithRestriction.putAllInputs(splittableParDo.getInputsMap());
pairWithRestriction.putOutputs("out", pairWithRestrictionOutId);
pairWithRestriction.setUniqueName(
generateUniquePCollectonName(
splittableParDo.getUniqueName() + "/PairWithRestriction", existingComponents));
pairWithRestriction.setSpec(
FunctionSpec.newBuilder()
.setUrn(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
.setPayload(splittableParDo.getSpec().getPayload()));
pairWithRestriction.setEnvironmentId(splittableParDo.getEnvironmentId());
rval.getComponentsBuilder()
.putTransforms(pairWithRestrictionId, pairWithRestriction.build());
}
String splitAndSizeId =
generateUniqueId(transformId + "/SplitAndSize", existingComponents::containsTransforms);
{
PTransform.Builder splitAndSize = PTransform.newBuilder();
splitAndSize.putInputs(mainInputName, pairWithRestrictionOutId);
splitAndSize.putAllInputs(sideInputs);
splitAndSize.putOutputs("out", splitAndSizeOutId);
splitAndSize.setUniqueName(
generateUniquePCollectonName(
splittableParDo.getUniqueName() + "/SplitAndSize", existingComponents));
splitAndSize.setSpec(
FunctionSpec.newBuilder()
.setUrn(PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
.setPayload(splittableParDo.getSpec().getPayload()));
splitAndSize.setEnvironmentId(splittableParDo.getEnvironmentId());
rval.getComponentsBuilder().putTransforms(splitAndSizeId, splitAndSize.build());
}
PTransform.Builder newCompositeRoot =
splittableParDo
.toBuilder()
// Clear the original splittable ParDo spec and add all the new transforms as
// children.
.clearSpec()
.addAllSubtransforms(Arrays.asList(pairWithRestrictionId, splitAndSizeId));
String processSizedElementsAndRestrictionsId =
generateUniqueId(
transformId + "/ProcessSizedElementsAndRestrictions",
existingComponents::containsTransforms);
String processSizedElementsInputPCollectionId = splitAndSizeOutId;
if (isDrain()) {
String truncateAndSizeCoderId =
generateUniqueId(
mainInputPCollection.getCoderId() + "/TruncateAndSize",
existingComponents::containsCoders);
rval.getComponentsBuilder()
.putCoders(
truncateAndSizeCoderId,
ModelCoders.kvCoder(
splitAndSizeOutCoderId, getOrAddDoubleCoder(existingComponents, rval)));
String truncateAndSizeOutId =
generateUniqueId(
mainInputPCollectionId + "/TruncateAndSize",
existingComponents::containsPcollections);
rval.getComponentsBuilder()
.putPcollections(
truncateAndSizeOutId,
PCollection.newBuilder()
.setCoderId(truncateAndSizeCoderId)
.setIsBounded(mainInputPCollection.getIsBounded())
.setWindowingStrategyId(mainInputPCollection.getWindowingStrategyId())
.setUniqueName(
generateUniquePCollectonName(
mainInputPCollection.getUniqueName() + "/TruncateAndSize",
existingComponents))
.build());
String truncateAndSizeId =
generateUniqueId(
transformId + "/TruncateAndSize", existingComponents::containsTransforms);
{
PTransform.Builder truncateAndSize = PTransform.newBuilder();
truncateAndSize.putInputs(mainInputName, splitAndSizeOutId);
truncateAndSize.putAllInputs(sideInputs);
truncateAndSize.putOutputs("out", truncateAndSizeOutId);
truncateAndSize.setUniqueName(
generateUniquePCollectonName(
splittableParDo.getUniqueName() + "/TruncateAndSize", existingComponents));
truncateAndSize.setSpec(
FunctionSpec.newBuilder()
.setUrn(PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
.setPayload(splittableParDo.getSpec().getPayload()));
truncateAndSize.setEnvironmentId(splittableParDo.getEnvironmentId());
rval.getComponentsBuilder().putTransforms(truncateAndSizeId, truncateAndSize.build());
}
newCompositeRoot.addSubtransforms(truncateAndSizeId);
processSizedElementsInputPCollectionId = truncateAndSizeOutId;
}
{
PTransform.Builder processSizedElementsAndRestrictions = PTransform.newBuilder();
processSizedElementsAndRestrictions.putInputs(
mainInputName, processSizedElementsInputPCollectionId);
processSizedElementsAndRestrictions.putAllInputs(sideInputs);
processSizedElementsAndRestrictions.putAllOutputs(splittableParDo.getOutputsMap());
processSizedElementsAndRestrictions.setUniqueName(
generateUniquePCollectonName(
splittableParDo.getUniqueName() + "/ProcessSizedElementsAndRestrictions",
existingComponents));
processSizedElementsAndRestrictions.setSpec(
FunctionSpec.newBuilder()
.setUrn(
PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
.setPayload(splittableParDo.getSpec().getPayload()));
processSizedElementsAndRestrictions.setEnvironmentId(splittableParDo.getEnvironmentId());
rval.getComponentsBuilder()
.putTransforms(
processSizedElementsAndRestrictionsId,
processSizedElementsAndRestrictions.build());
}
newCompositeRoot.addSubtransforms(processSizedElementsAndRestrictionsId);
rval.setPtransform(newCompositeRoot);
return rval.build();
} catch (IOException e) {
throw new RuntimeException("Unable to perform expansion for transform " + transformId, e);
}
}
}
private static String getOrAddDoubleCoder(
ComponentsOrBuilder existingComponents, MessageWithComponents.Builder out) {
for (Map.Entry<String, Coder> coder : existingComponents.getCodersMap().entrySet()) {
if (ModelCoders.DOUBLE_CODER_URN.equals(coder.getValue().getSpec().getUrn())) {
return coder.getKey();
}
}
String doubleCoderId = generateUniqueId("DoubleCoder", existingComponents::containsCoders);
out.getComponentsBuilder()
.putCoders(
doubleCoderId,
Coder.newBuilder()
.setSpec(FunctionSpec.newBuilder().setUrn(ModelCoders.DOUBLE_CODER_URN))
.build());
return doubleCoderId;
}
/**
* Returns a PCollection name that uses the supplied prefix that does not exist in {@code
* existingComponents}.
*/
private static String generateUniquePCollectonName(
String prefix, ComponentsOrBuilder existingComponents) {
return generateUniqueId(
prefix,
input -> {
for (PCollection pc : existingComponents.getPcollectionsMap().values()) {
if (input.equals(pc.getUniqueName())) {
return true;
}
}
return false;
});
}
/** Generates a unique id given a prefix and a predicate to compare if the id is already used. */
private static String generateUniqueId(String prefix, Predicate<String> isExistingId) {
int i = 0;
while (isExistingId.test(prefix + i)) {
i += 1;
}
return prefix + i;
}
}