blob: 7e6ba7b941c7f347dd61d845e175ce1ac678b239 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN;
import static org.apache.beam.sdk.transforms.reflect.DoFnSignatures.getStateSpecOrThrow;
import static org.apache.beam.sdk.transforms.reflect.DoFnSignatures.getTimerSpecOrThrow;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.Parameter.Type;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput.Builder;
import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
import org.apache.beam.runners.core.construction.PTransformTranslation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
import org.apache.beam.sdk.util.DoFnWithExecutionInformation;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
/**
 * Utilities for interacting with {@link ParDo} instances and {@link ParDoPayload} protos.
 *
 * <p>Covers both directions: translating a Java {@link ParDo.MultiOutput} (its {@link DoFn}, state,
 * timers, timer families, side inputs and parameters) into Runner API protos, and reading Java
 * objects ({@link DoFn}, main output tag, side inputs, schema information) back out of a {@link
 * ParDoPayload}.
 */
public class ParDoTranslation {
  /** The URN for an unknown Java {@link DoFn}. */
  public static final String CUSTOM_JAVA_DO_FN_URN = "beam:dofn:javasdk:0.1";
  /** The URN for an unknown Java {@link ViewFn}. */
  public static final String CUSTOM_JAVA_VIEW_FN_URN = "beam:viewfn:javasdk:0.1";
  /** The URN for an unknown Java {@link WindowMappingFn}. */
  public static final String CUSTOM_JAVA_WINDOW_MAPPING_FN_URN = "beam:windowmappingfn:javasdk:0.1";

  /**
   * A {@link TransformTranslator} for {@link ParDo.MultiOutput}.
   *
   * <p>Besides emitting the {@link ParDoPayload}, it materializes every declared timer as a
   * synthetic {@link PCollection} that is added as both an input and an output of the translated
   * transform (see https://s.apache.org/beam-portability-timers).
   */
  public static class ParDoTranslator implements TransformTranslator<MultiOutput<?, ?>> {

    /** Returns a new {@link ParDoTranslator}. */
    public static TransformTranslator create() {
      return new ParDoTranslator();
    }

    private ParDoTranslator() {}

    /** Always returns {@link PTransformTranslation#PAR_DO_TRANSFORM_URN}. */
    @Override
    public String getUrn(ParDo.MultiOutput<?, ?> transform) {
      return PAR_DO_TRANSFORM_URN;
    }

    /** Returns whether {@code pTransform} is a {@link ParDo.MultiOutput}. */
    @Override
    public boolean canTranslate(PTransform<?, ?> pTransform) {
      return pTransform instanceof ParDo.MultiOutput;
    }

    /**
     * Translates {@code appliedPTransform}, which must wrap a {@link ParDo.MultiOutput}, into a
     * {@link RunnerApi.PTransform} whose spec carries a {@link ParDoPayload}.
     *
     * <p>For every timer spec in the payload, a {@link PCollection} of {@code KV<key,
     * Timer<Void>>} named {@code "<fullName>.<timerName>"} is registered with {@code components}
     * and added under the timer's local name to both the inputs and the outputs of the result. The
     * key coder is read from the main input, which is therefore cast to a {@link KvCoder} whenever
     * timers are declared.
     */
    @Override
    public RunnerApi.PTransform translate(
        AppliedPTransform<?, ?, ?> appliedPTransform,
        List<AppliedPTransform<?, ?, ?>> subtransforms,
        SdkComponents components)
        throws IOException {
      RunnerApi.PTransform.Builder builder =
          PTransformTranslation.translateAppliedPTransform(
              appliedPTransform, subtransforms, components);
      AppliedPTransform<?, ?, ParDo.MultiOutput<?, ?>> appliedParDo =
          (AppliedPTransform<?, ?, ParDo.MultiOutput<?, ?>>) appliedPTransform;
      ParDoPayload payload = translateParDo(appliedParDo, components);
      builder.setSpec(
          RunnerApi.FunctionSpec.newBuilder()
              .setUrn(PAR_DO_TRANSFORM_URN)
              .setPayload(payload.toByteString())
              .build());
      builder.setEnvironmentId(components.getOnlyEnvironmentId());

      String mainInputName = getMainInputName(builder, payload);
      PCollection<KV<?, ?>> mainInput =
          (PCollection) appliedPTransform.getInputs().get(new TupleTag<>(mainInputName));

      // https://s.apache.org/beam-portability-timers
      // Add a PCollection and coder for each timer. Also treat them as inputs and outputs.
      for (String localTimerName : payload.getTimerSpecsMap().keySet()) {
        PCollection<?> timerPCollection =
            PCollection.createPrimitiveOutputInternal(
                // Create a dummy pipeline since we don't want to modify the current
                // users view of the pipeline they have constructed.
                Pipeline.create(),
                mainInput.getWindowingStrategy(),
                mainInput.isBounded(),
                KvCoder.of(
                    ((KvCoder) mainInput.getCoder()).getKeyCoder(),
                    // TODO: Add support for timer payloads to the SDK
                    // We currently assume that all payloads are unspecified.
                    Timer.Coder.of(VoidCoder.of())));
        timerPCollection.setName(
            String.format("%s.%s", appliedPTransform.getFullName(), localTimerName));
        String timerPCollectionId = components.registerPCollection(timerPCollection);
        builder.putInputs(localTimerName, timerPCollectionId);
        builder.putOutputs(localTimerName, timerPCollectionId);
      }

      return builder.build();
    }
  }

  /**
   * Builds the {@link ParDoPayload} for an applied {@link ParDo.MultiOutput}.
   *
   * <p>The main input is identified as the single input whose tag id is neither a side-input tag
   * nor a declared timer id. It is used to derive the {@link DoFnSchemaInformation}.
   *
   * @throws IllegalArgumentException if there is not exactly one such input
   */
  public static ParDoPayload translateParDo(
      AppliedPTransform<?, ?, ParDo.MultiOutput<?, ?>> appliedPTransform, SdkComponents components)
      throws IOException {
    final ParDo.MultiOutput<?, ?> parDo = appliedPTransform.getTransform();
    final Pipeline pipeline = appliedPTransform.getPipeline();
    final DoFn<?, ?> doFn = parDo.getFn();
    final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

    // Get main input.
    Set<String> allInputs =
        appliedPTransform.getInputs().keySet().stream()
            .map(TupleTag::getId)
            .collect(Collectors.toSet());
    Set<String> sideInputs =
        parDo.getSideInputs().values().stream()
            .map(s -> s.getTagInternal().getId())
            .collect(Collectors.toSet());
    Set<String> timerInputs = signature.timerDeclarations().keySet();
    String mainInputName =
        Iterables.getOnlyElement(Sets.difference(allInputs, Sets.union(sideInputs, timerInputs)));
    PCollection<?> mainInput =
        (PCollection<?>) appliedPTransform.getInputs().get(new TupleTag<>(mainInputName));

    final DoFnSchemaInformation doFnSchemaInformation =
        ParDo.getDoFnSchemaInformation(doFn, mainInput);
    return translateParDo(parDo, doFnSchemaInformation, pipeline, components);
  }

  /**
   * Translates a ParDo into a {@link ParDoPayload}.
   *
   * <p>For a splittable {@link DoFn}, the restriction coder is resolved through the pipeline's
   * coder registry and registered with {@code components}. Otherwise the restriction coder id is
   * the empty string.
   */
  public static ParDoPayload translateParDo(
      ParDo.MultiOutput<?, ?> parDo,
      DoFnSchemaInformation doFnSchemaInformation,
      Pipeline pipeline,
      SdkComponents components)
      throws IOException {
    final DoFn<?, ?> doFn = parDo.getFn();
    final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
    final String restrictionCoderId;
    if (signature.processElement().isSplittable()) {
      final Coder<?> restrictionCoder =
          DoFnInvokers.invokerFor(doFn).invokeGetRestrictionCoder(pipeline.getCoderRegistry());
      restrictionCoderId = components.registerCoder(restrictionCoder);
    } else {
      restrictionCoderId = "";
    }

    return payloadForParDoLike(
        new ParDoLike() {
          @Override
          public FunctionSpec translateDoFn(SdkComponents newComponents) {
            return ParDoTranslation.translateDoFn(
                parDo.getFn(),
                parDo.getMainOutputTag(),
                parDo.getSideInputs(),
                doFnSchemaInformation,
                newComponents);
          }

          @Override
          public List<RunnerApi.Parameter> translateParameters() {
            return ParDoTranslation.translateParameters(
                signature.processElement().extraParameters());
          }

          @Override
          public Map<String, SideInput> translateSideInputs(SdkComponents components) {
            // Qualified call: this anonymous class's own translateSideInputs shadows the static
            // helper by name.
            return ParDoTranslation.translateSideInputs(
                new ArrayList<>(parDo.getSideInputs().values()), components);
          }

          @Override
          public Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components)
              throws IOException {
            Map<String, RunnerApi.StateSpec> stateSpecs = new HashMap<>();
            for (Map.Entry<String, StateDeclaration> state :
                signature.stateDeclarations().entrySet()) {
              RunnerApi.StateSpec spec =
                  translateStateSpec(getStateSpecOrThrow(state.getValue(), doFn), components);
              stateSpecs.put(state.getKey(), spec);
            }
            return stateSpecs;
          }

          @Override
          public Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents newComponents) {
            Map<String, RunnerApi.TimerSpec> timerSpecs = new HashMap<>();
            for (Map.Entry<String, TimerDeclaration> timer :
                signature.timerDeclarations().entrySet()) {
              RunnerApi.TimerSpec spec =
                  translateTimerSpec(getTimerSpecOrThrow(timer.getValue(), doFn), newComponents);
              timerSpecs.put(timer.getKey(), spec);
            }
            return timerSpecs;
          }

          @Override
          public Map<String, RunnerApi.TimerFamilySpec> translateTimerFamilySpecs(
              SdkComponents newComponents) {
            Map<String, RunnerApi.TimerFamilySpec> timerFamilySpecs = new HashMap<>();
            for (Map.Entry<String, DoFnSignature.TimerFamilyDeclaration> timerFamily :
                signature.timerFamilyDeclarations().entrySet()) {
              RunnerApi.TimerFamilySpec spec =
                  translateTimerFamilySpec(
                      DoFnSignatures.getTimerFamilySpecOrThrow(timerFamily.getValue(), doFn),
                      newComponents);
              timerFamilySpecs.put(timerFamily.getKey(), spec);
            }
            return timerFamilySpecs;
          }

          @Override
          public boolean isSplittable() {
            return signature.processElement().isSplittable();
          }

          @Override
          public String translateRestrictionCoderId(SdkComponents newComponents) {
            return restrictionCoderId;
          }
        },
        components);
  }

  /**
   * Translates DoFn parameters to protos, dropping those {@link #translateParameter} does not
   * support (for which it returns {@code null}).
   */
  public static List<RunnerApi.Parameter> translateParameters(List<Parameter> params) {
    List<RunnerApi.Parameter> parameters = new ArrayList<>();
    for (Parameter parameter : params) {
      RunnerApi.Parameter protoParameter = translateParameter(parameter);
      if (protoParameter != null) {
        parameters.add(protoParameter);
      }
    }
    return parameters;
  }

  /** Deserializes the {@link DoFn} stored in {@code payload}. */
  public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException {
    return doFnWithExecutionInformationFromProto(payload.getDoFn()).getDoFn();
  }

  /**
   * Returns the {@link DoFn} of {@code application}: directly when it is a {@link
   * ParDo.MultiOutput}, otherwise by translating it to a {@link ParDoPayload} and deserializing.
   */
  public static DoFn<?, ?> getDoFn(AppliedPTransform<?, ?, ?> application) throws IOException {
    PTransform<?, ?> transform = application.getTransform();
    if (transform instanceof ParDo.MultiOutput) {
      return ((ParDo.MultiOutput<?, ?>) transform).getFn();
    }
    return getDoFn(getParDoPayload(application));
  }

  /**
   * Returns the schema information of {@code application}'s translated payload; {@link
   * IOException}s are rethrown as {@link RuntimeException}.
   */
  public static DoFnSchemaInformation getSchemaInformation(AppliedPTransform<?, ?, ?> application) {
    try {
      return getSchemaInformation(getParDoPayload(application));
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Returns the schema information stored in {@code pTransform}'s spec payload; {@link
   * IOException}s are rethrown as {@link RuntimeException}.
   */
  public static DoFnSchemaInformation getSchemaInformation(RunnerApi.PTransform pTransform) {
    try {
      return getSchemaInformation(getParDoPayload(pTransform));
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /** Returns the schema information stored alongside the {@link DoFn} in {@code payload}. */
  public static DoFnSchemaInformation getSchemaInformation(ParDoPayload payload) {
    return doFnWithExecutionInformationFromProto(payload.getDoFn()).getSchemaInformation();
  }

  /** Returns the main output tag stored alongside the {@link DoFn} in {@code payload}. */
  public static TupleTag<?> getMainOutputTag(ParDoPayload payload)
      throws InvalidProtocolBufferException {
    return doFnWithExecutionInformationFromProto(payload.getDoFn()).getMainOutputTag();
  }

  /**
   * Returns the side-input mapping of {@code application}'s translated payload; {@link
   * IOException}s are rethrown as {@link RuntimeException}.
   */
  public static Map<String, PCollectionView<?>> getSideInputMapping(
      AppliedPTransform<?, ?, ?> application) {
    try {
      return getSideInputMapping(getParDoPayload(application));
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Returns the side-input mapping stored in {@code pTransform}'s spec payload; {@link
   * IOException}s are rethrown as {@link RuntimeException}.
   */
  public static Map<String, PCollectionView<?>> getSideInputMapping(
      RunnerApi.PTransform pTransform) {
    try {
      return getSideInputMapping(getParDoPayload(pTransform));
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /** Returns the side-input mapping stored alongside the {@link DoFn} in {@code payload}. */
  public static Map<String, PCollectionView<?>> getSideInputMapping(ParDoPayload payload) {
    return doFnWithExecutionInformationFromProto(payload.getDoFn()).getSideInputMapping();
  }

  /**
   * Returns the main output tag of {@code application}: directly when it is a {@link
   * ParDo.MultiOutput}, otherwise from its translated payload.
   */
  public static TupleTag<?> getMainOutputTag(AppliedPTransform<?, ?, ?> application)
      throws IOException {
    PTransform<?, ?> transform = application.getTransform();
    if (transform instanceof ParDo.MultiOutput) {
      return ((ParDo.MultiOutput<?, ?>) transform).getMainOutputTag();
    }
    return getMainOutputTag(getParDoPayload(application));
  }

  /**
   * Returns the additional (non-main) output tags of {@code application}.
   *
   * <p>For a {@link ParDo.MultiOutput} these come from the transform itself. Otherwise the
   * transform is translated to a proto and every output key other than the main output tag's id is
   * returned.
   */
  public static TupleTagList getAdditionalOutputTags(AppliedPTransform<?, ?, ?> application)
      throws IOException {
    PTransform<?, ?> transform = application.getTransform();
    if (transform instanceof ParDo.MultiOutput) {
      return ((ParDo.MultiOutput<?, ?>) transform).getAdditionalOutputTags();
    }

    RunnerApi.PTransform protoTransform =
        PTransformTranslation.toProto(
            application, SdkComponents.create(application.getPipeline().getOptions()));

    ParDoPayload payload = ParDoPayload.parseFrom(protoTransform.getSpec().getPayload());
    TupleTag<?> mainOutputTag = getMainOutputTag(payload);
    Set<String> outputTags =
        Sets.difference(
            protoTransform.getOutputsMap().keySet(), Collections.singleton(mainOutputTag.getId()));

    ArrayList<TupleTag<?>> additionalOutputTags = new ArrayList<>();
    for (String outputTag : outputTags) {
      additionalOutputTags.add(new TupleTag<>(outputTag));
    }
    return TupleTagList.of(additionalOutputTags);
  }

  /** Returns the coder of every {@link PCollection} output of {@code application}, keyed by tag. */
  public static Map<TupleTag<?>, Coder<?>> getOutputCoders(AppliedPTransform<?, ?, ?> application) {
    return application.getOutputs().entrySet().stream()
        .filter(e -> e.getValue() instanceof PCollection)
        .collect(Collectors.toMap(e -> e.getKey(), e -> ((PCollection) e.getValue()).getCoder()));
  }

  /**
   * Returns the side-input views of {@code application}.
   *
   * <p>For a {@link ParDo.MultiOutput} these come from the transform itself. Otherwise each side
   * input in the translated payload is rehydrated against the application's input {@link
   * PCollection} with the same tag.
   *
   * @throws NullPointerException if a payload side input has no matching application input
   */
  public static List<PCollectionView<?>> getSideInputs(AppliedPTransform<?, ?, ?> application)
      throws IOException {
    PTransform<?, ?> transform = application.getTransform();
    if (transform instanceof ParDo.MultiOutput) {
      return new ArrayList<>(((ParDo.MultiOutput<?, ?>) transform).getSideInputs().values());
    }

    SdkComponents sdkComponents = SdkComponents.create(application.getPipeline().getOptions());
    RunnerApi.PTransform parDoProto = PTransformTranslation.toProto(application, sdkComponents);
    ParDoPayload payload = ParDoPayload.parseFrom(parDoProto.getSpec().getPayload());

    List<PCollectionView<?>> views = new ArrayList<>();
    RehydratedComponents components =
        RehydratedComponents.forComponents(sdkComponents.toComponents());
    for (Map.Entry<String, SideInput> sideInputEntry : payload.getSideInputsMap().entrySet()) {
      String sideInputTag = sideInputEntry.getKey();
      RunnerApi.SideInput sideInput = sideInputEntry.getValue();
      PCollection<?> originalPCollection =
          checkNotNull(
              (PCollection<?>) application.getInputs().get(new TupleTag<>(sideInputTag)),
              "no input with tag %s",
              sideInputTag);
      views.add(
          PCollectionViewTranslation.viewFromProto(
              sideInput, sideInputTag, originalPCollection, parDoProto, components));
    }
    return views;
  }

  /**
   * Returns the main input {@link RunnerApi.PCollection} of a ParDo-like {@code ptransform}.
   *
   * @throws IllegalArgumentException if the URN does not carry a {@link ParDoPayload}, or there is
   *     not exactly one main input
   */
  public static RunnerApi.PCollection getMainInput(
      RunnerApi.PTransform ptransform, Components components) throws IOException {
    // getMainInputName validates the URN before it parses the payload.
    return components.getPcollectionsOrThrow(
        ptransform.getInputsOrThrow(getMainInputName(ptransform)));
  }

  /** Returns the name of the main input of the ptransform. */
  public static String getMainInputName(RunnerApi.PTransformOrBuilder ptransform)
      throws IOException {
    checkParDoLikeUrn(ptransform);
    ParDoPayload payload = ParDoPayload.parseFrom(ptransform.getSpec().getPayload());
    return getMainInputName(ptransform, payload);
  }

  /**
   * Returns the single input name that is neither a side input nor a timer spec of {@code
   * payload}.
   */
  private static String getMainInputName(
      RunnerApi.PTransformOrBuilder ptransform, RunnerApi.ParDoPayload payload) {
    return Iterables.getOnlyElement(
        Sets.difference(
            ptransform.getInputsMap().keySet(),
            Sets.union(payload.getSideInputsMap().keySet(), payload.getTimerSpecsMap().keySet())));
  }

  /**
   * Checks that {@code ptransform}'s URN is one whose spec payload is a {@link ParDoPayload}: the
   * plain ParDo URN or one of the splittable ParDo expansion URNs.
   *
   * @throws IllegalArgumentException for any other URN
   */
  private static void checkParDoLikeUrn(RunnerApi.PTransformOrBuilder ptransform) {
    String urn = ptransform.getSpec().getUrn();
    checkArgument(
        PAR_DO_TRANSFORM_URN.equals(urn)
            || SPLITTABLE_PAIR_WITH_RESTRICTION_URN.equals(urn)
            || SPLITTABLE_SPLIT_RESTRICTION_URN.equals(urn)
            || SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN.equals(urn)
            || SPLITTABLE_PROCESS_ELEMENTS_URN.equals(urn)
            || SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN.equals(urn),
        "Unexpected payload type %s",
        urn);
  }

  /**
   * Translate state specs.
   *
   * <p>Each coder (and, for combining state, the {@link Combine.CombineFn}) is registered with
   * {@code components}. A coder registration failure surfaces as a {@link RuntimeException}.
   */
  public static RunnerApi.StateSpec translateStateSpec(
      StateSpec<?> stateSpec, final SdkComponents components) throws IOException {
    final RunnerApi.StateSpec.Builder builder = RunnerApi.StateSpec.newBuilder();

    return stateSpec.match(
        new StateSpec.Cases<RunnerApi.StateSpec>() {
          @Override
          public RunnerApi.StateSpec dispatchValue(Coder<?> valueCoder) {
            return builder
                .setReadModifyWriteSpec(
                    RunnerApi.ReadModifyWriteStateSpec.newBuilder()
                        .setCoderId(registerCoderOrThrow(components, valueCoder)))
                .build();
          }

          @Override
          public RunnerApi.StateSpec dispatchBag(Coder<?> elementCoder) {
            return builder
                .setBagSpec(
                    RunnerApi.BagStateSpec.newBuilder()
                        .setElementCoderId(registerCoderOrThrow(components, elementCoder)))
                .build();
          }

          @Override
          public RunnerApi.StateSpec dispatchCombining(
              Combine.CombineFn<?, ?, ?> combineFn, Coder<?> accumCoder) {
            return builder
                .setCombiningSpec(
                    RunnerApi.CombiningStateSpec.newBuilder()
                        .setAccumulatorCoderId(registerCoderOrThrow(components, accumCoder))
                        .setCombineFn(CombineTranslation.toProto(combineFn, components)))
                .build();
          }

          @Override
          public RunnerApi.StateSpec dispatchMap(Coder<?> keyCoder, Coder<?> valueCoder) {
            return builder
                .setMapSpec(
                    RunnerApi.MapStateSpec.newBuilder()
                        .setKeyCoderId(registerCoderOrThrow(components, keyCoder))
                        .setValueCoderId(registerCoderOrThrow(components, valueCoder)))
                .build();
          }

          @Override
          public RunnerApi.StateSpec dispatchSet(Coder<?> elementCoder) {
            return builder
                .setSetSpec(
                    RunnerApi.SetStateSpec.newBuilder()
                        .setElementCoderId(registerCoderOrThrow(components, elementCoder)))
                .build();
          }
        });
  }

  /**
   * Rebuilds a Java {@link StateSpec} from its proto form.
   *
   * @throws UnsupportedOperationException for combining state whose CombineFn is not a serialized
   *     Java CombineFn
   * @throws IllegalArgumentException if no spec case is set
   */
  @VisibleForTesting
  static StateSpec<?> fromProto(RunnerApi.StateSpec stateSpec, RehydratedComponents components)
      throws IOException {
    switch (stateSpec.getSpecCase()) {
      case READ_MODIFY_WRITE_SPEC:
        return StateSpecs.value(
            components.getCoder(stateSpec.getReadModifyWriteSpec().getCoderId()));
      case BAG_SPEC:
        return StateSpecs.bag(components.getCoder(stateSpec.getBagSpec().getElementCoderId()));
      case COMBINING_SPEC:
        FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn();

        if (!combineFnSpec.getUrn().equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN)) {
          throw new UnsupportedOperationException(
              String.format(
                  "Cannot create %s from non-Java %s: %s",
                  StateSpec.class.getSimpleName(),
                  Combine.CombineFn.class.getSimpleName(),
                  combineFnSpec.getUrn()));
        }

        Combine.CombineFn<?, ?, ?> combineFn =
            (Combine.CombineFn<?, ?, ?>)
                SerializableUtils.deserializeFromByteArray(
                    combineFnSpec.getPayload().toByteArray(),
                    Combine.CombineFn.class.getSimpleName());

        // Rawtype coder cast because it is required to be a valid accumulator coder
        // for the CombineFn, by construction
        return StateSpecs.combining(
            (Coder) components.getCoder(stateSpec.getCombiningSpec().getAccumulatorCoderId()),
            combineFn);
      case MAP_SPEC:
        return StateSpecs.map(
            components.getCoder(stateSpec.getMapSpec().getKeyCoderId()),
            components.getCoder(stateSpec.getMapSpec().getValueCoderId()));
      case SET_SPEC:
        return StateSpecs.set(components.getCoder(stateSpec.getSetSpec().getElementCoderId()));
      case SPEC_NOT_SET:
      default:
        throw new IllegalArgumentException(
            String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec));
    }
  }

  /**
   * Registers {@code coder} with {@code components}, rethrowing {@link IOException} unchecked.
   * This lets it be called from the non-throwing {@link StateSpec.Cases} and timer translators.
   */
  private static String registerCoderOrThrow(SdkComponents components, Coder coder) {
    try {
      return components.registerCoder(coder);
    } catch (IOException exc) {
      throw new RuntimeException("Failure to register coder", exc);
    }
  }

  /** Translates a {@link TimerSpec}; the timer coder is always {@code Timer.Coder<Void>}. */
  public static RunnerApi.TimerSpec translateTimerSpec(TimerSpec timer, SdkComponents components) {
    return RunnerApi.TimerSpec.newBuilder()
        .setTimeDomain(translateTimeDomain(timer.getTimeDomain()))
        // TODO: Add support for timer payloads to the SDK
        // We currently assume that all payloads are unspecified.
        .setTimerCoderId(registerCoderOrThrow(components, Timer.Coder.of(VoidCoder.of())))
        .build();
  }

  /** Translates a timer family's {@link TimerSpec}; the coder is always {@code Timer.Coder<Void>}. */
  public static RunnerApi.TimerFamilySpec translateTimerFamilySpec(
      TimerSpec timer, SdkComponents components) {
    return RunnerApi.TimerFamilySpec.newBuilder()
        .setTimeDomain(translateTimeDomain(timer.getTimeDomain()))
        .setTimerFamilyCoderId(registerCoderOrThrow(components, Timer.Coder.of(VoidCoder.of())))
        .build();
  }

  /** Maps a Java {@link TimeDomain} to its proto enum value. */
  private static RunnerApi.TimeDomain.Enum translateTimeDomain(TimeDomain timeDomain) {
    switch (timeDomain) {
      case EVENT_TIME:
        return RunnerApi.TimeDomain.Enum.EVENT_TIME;
      case PROCESSING_TIME:
        return RunnerApi.TimeDomain.Enum.PROCESSING_TIME;
      case SYNCHRONIZED_PROCESSING_TIME:
        return RunnerApi.TimeDomain.Enum.SYNCHRONIZED_PROCESSING_TIME;
      default:
        throw new IllegalArgumentException("Unknown time domain: " + timeDomain);
    }
  }

  /**
   * Serializes {@code fn} with its main output tag, side-input mapping and schema information into
   * a {@link FunctionSpec} with URN {@link #CUSTOM_JAVA_DO_FN_URN}. {@code components} is currently
   * unused.
   */
  public static FunctionSpec translateDoFn(
      DoFn<?, ?> fn,
      TupleTag<?> tag,
      Map<String, PCollectionView<?>> sideInputMapping,
      DoFnSchemaInformation doFnSchemaInformation,
      SdkComponents components) {
    return FunctionSpec.newBuilder()
        .setUrn(CUSTOM_JAVA_DO_FN_URN)
        .setPayload(
            ByteString.copyFrom(
                SerializableUtils.serializeToByteArray(
                    DoFnWithExecutionInformation.of(
                        fn, tag, sideInputMapping, doFnSchemaInformation))))
        .build();
  }

  /**
   * Deserializes the {@link DoFnWithExecutionInformation} written by {@link #translateDoFn}.
   *
   * <p>NOTE(review): the payload is deserialized via {@link SerializableUtils}; only call this on
   * pipeline protos from a trusted source.
   *
   * @throws IllegalArgumentException if {@code fnSpec}'s URN is not {@link #CUSTOM_JAVA_DO_FN_URN}
   */
  public static DoFnWithExecutionInformation doFnWithExecutionInformationFromProto(
      FunctionSpec fnSpec) {
    checkArgument(
        fnSpec.getUrn().equals(CUSTOM_JAVA_DO_FN_URN),
        "Expected %s to be %s with URN %s, but URN was %s",
        DoFn.class.getSimpleName(),
        FunctionSpec.class.getSimpleName(),
        CUSTOM_JAVA_DO_FN_URN,
        fnSpec.getUrn());
    byte[] serializedFn = fnSpec.getPayload().toByteArray();
    return (DoFnWithExecutionInformation)
        SerializableUtils.deserializeFromByteArray(serializedFn, "Custom DoFn With Execution Info");
  }

  /**
   * Translates a Java DoFn parameter to a proto representation.
   *
   * <p>Returns {@code null} rather than crashing for parameters that are not yet supported, to
   * allow legacy Java-based runners to perform a proto round-trip and afterwards use {@link
   * DoFnSignatures} to analyze.
   *
   * <p>The proto definition for parameters is provisional and those parameters that are not needed
   * for portability will be removed from the enum.
   */
  // Using nullability instead of optional because of shading
  public static @Nullable RunnerApi.Parameter translateParameter(Parameter parameter) {
    return parameter.match(
        // Results are nullable: every parameter kind without an explicit case maps to null.
        new Cases.WithDefault<RunnerApi.Parameter>() {
          @Override
          public RunnerApi.Parameter dispatch(WindowParameter p) {
            return RunnerApi.Parameter.newBuilder().setType(Type.Enum.WINDOW).build();
          }

          @Override
          public RunnerApi.Parameter dispatch(RestrictionTrackerParameter p) {
            return RunnerApi.Parameter.newBuilder().setType(Type.Enum.RESTRICTION_TRACKER).build();
          }

          @Override
          protected @Nullable RunnerApi.Parameter dispatchDefault(Parameter p) {
            return null;
          }
        });
  }

  /** Translates each view into a {@link SideInput}, keyed by the view's tag id. */
  public static Map<String, SideInput> translateSideInputs(
      List<PCollectionView<?>> views, SdkComponents components) {
    Map<String, SideInput> sideInputs = new HashMap<>();
    for (PCollectionView<?> sideInput : views) {
      sideInputs.put(
          sideInput.getTagInternal().getId(),
          ParDoTranslation.translateView(sideInput, components));
    }
    return sideInputs;
  }

  /**
   * Translates a {@link PCollectionView} into a {@link SideInput}. The access pattern URN is taken
   * from the view's materialization, and the view and window-mapping functions are serialized.
   */
  public static SideInput translateView(PCollectionView<?> view, SdkComponents components) {
    Builder builder = SideInput.newBuilder();
    builder.setAccessPattern(
        FunctionSpec.newBuilder().setUrn(view.getViewFn().getMaterialization().getUrn()).build());
    builder.setViewFn(translateViewFn(view.getViewFn(), components));
    builder.setWindowMappingFn(translateWindowMappingFn(view.getWindowMappingFn(), components));
    return builder.build();
  }

  /**
   * Serializes {@code viewFn} into a {@link FunctionSpec} with URN {@link
   * #CUSTOM_JAVA_VIEW_FN_URN}. {@code components} is currently unused.
   */
  public static FunctionSpec translateViewFn(ViewFn<?, ?> viewFn, SdkComponents components) {
    return FunctionSpec.newBuilder()
        .setUrn(CUSTOM_JAVA_VIEW_FN_URN)
        .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(viewFn)))
        .build();
  }

  /**
   * Translates {@code transform}, with no subtransforms and fresh {@link SdkComponents}, and parses
   * its spec payload as a {@link ParDoPayload}. The URN is not checked.
   */
  private static ParDoPayload getParDoPayload(AppliedPTransform<?, ?, ?> transform)
      throws IOException {
    SdkComponents components = SdkComponents.create(transform.getPipeline().getOptions());
    RunnerApi.PTransform parDoPTransform =
        PTransformTranslation.toProto(transform, Collections.emptyList(), components);
    return getParDoPayload(parDoPTransform);
  }

  /** Parses {@code parDoPTransform}'s spec payload as a {@link ParDoPayload}; URN not checked. */
  private static ParDoPayload getParDoPayload(RunnerApi.PTransform parDoPTransform)
      throws IOException {
    return ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
  }

  /** Returns whether the translated payload declares any state, timer or timer-family specs. */
  public static boolean usesStateOrTimers(AppliedPTransform<?, ?, ?> transform) throws IOException {
    ParDoPayload payload = getParDoPayload(transform);
    return payload.getStateSpecsCount() > 0
        || payload.getTimerSpecsCount() > 0
        || payload.getTimerFamilySpecsCount() > 0;
  }

  /** Returns the {@code splittable} flag of the translated payload. */
  public static boolean isSplittable(AppliedPTransform<?, ?, ?> transform) throws IOException {
    ParDoPayload payload = getParDoPayload(transform);
    return payload.getSplittable();
  }

  /**
   * Serializes {@code windowMappingFn} into a {@link FunctionSpec} with URN {@link
   * #CUSTOM_JAVA_WINDOW_MAPPING_FN_URN}. {@code components} is currently unused.
   */
  public static FunctionSpec translateWindowMappingFn(
      WindowMappingFn<?> windowMappingFn, SdkComponents components) {
    return FunctionSpec.newBuilder()
        .setUrn(CUSTOM_JAVA_WINDOW_MAPPING_FN_URN)
        .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(windowMappingFn)))
        .build();
  }

  /** These methods drive to-proto translation from Java and from rehydrated ParDos. */
  public interface ParDoLike {
    /** Returns the {@link FunctionSpec} for the DoFn. */
    FunctionSpec translateDoFn(SdkComponents newComponents);

    /** Returns the proto form of the DoFn's extra parameters. */
    List<RunnerApi.Parameter> translateParameters();

    /** Returns side inputs keyed by local name. */
    Map<String, RunnerApi.SideInput> translateSideInputs(SdkComponents components);

    /** Returns state specs keyed by state id. */
    Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components)
        throws IOException;

    /** Returns timer specs keyed by timer id. */
    Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents newComponents);

    /** Returns timer-family specs keyed by timer-family id. */
    Map<String, RunnerApi.TimerFamilySpec> translateTimerFamilySpecs(SdkComponents newComponents);

    /** Returns whether the DoFn is splittable. */
    boolean isSplittable();

    /** Returns the restriction coder id. */
    String translateRestrictionCoderId(SdkComponents newComponents);
  }

  /**
   * Assembles a {@link ParDoPayload} from {@code parDo}, passing {@code components} to every
   * translation step that takes {@link SdkComponents}.
   */
  public static ParDoPayload payloadForParDoLike(ParDoLike parDo, SdkComponents components)
      throws IOException {
    return ParDoPayload.newBuilder()
        .setDoFn(parDo.translateDoFn(components))
        .addAllParameters(parDo.translateParameters())
        .putAllStateSpecs(parDo.translateStateSpecs(components))
        .putAllTimerSpecs(parDo.translateTimerSpecs(components))
        .putAllTimerFamilySpecs(parDo.translateTimerFamilySpecs(components))
        .putAllSideInputs(parDo.translateSideInputs(components))
        .setSplittable(parDo.isSplittable())
        .setRestrictionCoderId(parDo.translateRestrictionCoderId(components))
        .build();
  }
}