blob: f88cbe56defa6677084616f30cbc69056452d6ae [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;
import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Optional;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.Parameter.Type;
import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput.Builder;
import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.WindowingStrategy;
/** Utilities for interacting with {@link ParDo} instances and {@link ParDoPayload} protos. */
public class ParDoTranslation {
/**
 * The URN for an unknown Java {@link DoFn}.
 *
 * <p>Function specs written by this class under this URN carry a Java-serialized {@code
 * DoFnAndMainOutput} (the {@link DoFn} together with its main output tag) as their payload.
 */
public static final String CUSTOM_JAVA_DO_FN_URN = "urn:beam:dofn:javasdk:0.1";
/**
 * The URN for an unknown Java {@link ViewFn}.
 *
 * <p>Function specs written by this class under this URN carry the Java-serialized {@link ViewFn}
 * as their payload.
 */
public static final String CUSTOM_JAVA_VIEW_FN_URN = "urn:beam:viewfn:javasdk:0.1";
/**
 * The URN for an unknown Java {@link WindowMappingFn}.
 *
 * <p>Function specs written by this class under this URN carry the Java-serialized {@link
 * WindowMappingFn} as their payload.
 */
public static final String CUSTOM_JAVA_WINDOW_MAPPING_FN_URN =
"urn:beam:windowmappingfn:javasdk:0.1";
/** A {@link TransformPayloadTranslator} for {@link ParDo}. */
public static class ParDoPayloadTranslator
implements TransformPayloadTranslator<MultiOutput<?, ?>> {
public static TransformPayloadTranslator create() {
return new ParDoPayloadTranslator();
}
private ParDoPayloadTranslator() {}
@Override
public String getUrn(ParDo.MultiOutput<?, ?> transform) {
return PAR_DO_TRANSFORM_URN;
}
@Override
public FunctionSpec translate(
AppliedPTransform<?, ?, MultiOutput<?, ?>> transform, SdkComponents components)
throws IOException {
ParDoPayload payload = toProto(transform.getTransform(), components);
return RunnerApi.FunctionSpec.newBuilder()
.setUrn(PAR_DO_TRANSFORM_URN)
.setPayload(payload.toByteString())
.build();
}
@Override
public PTransformTranslation.RawPTransform<?, ?> rehydrate(
RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents)
throws IOException {
return new RawParDo<>(protoTransform, rehydratedComponents);
}
/** Registers {@link ParDoPayloadTranslator}. */
@AutoService(TransformPayloadTranslatorRegistrar.class)
public static class Registrar implements TransformPayloadTranslatorRegistrar {
@Override
public Map<? extends Class<? extends PTransform>, ? extends TransformPayloadTranslator>
getTransformPayloadTranslators() {
return Collections.singletonMap(ParDo.MultiOutput.class, new ParDoPayloadTranslator());
}
@Override
public Map<String, ? extends TransformPayloadTranslator> getTransformRehydrators() {
return Collections.singletonMap(PAR_DO_TRANSFORM_URN, new ParDoPayloadTranslator());
}
}
}
/**
 * Translates a Java {@link ParDo.MultiOutput} into a {@link ParDoPayload}.
 *
 * <p>The {@link DoFn} and main output tag go into the payload's function spec. Extra {@code
 * processElement} parameters, state declarations and timer declarations are read from the {@link
 * DoFnSignature} of the {@link DoFn}; side inputs are read from the transform itself.
 *
 * @param parDo the transform to translate
 * @param components the components that state spec coders are registered with
 * @return the payload describing {@code parDo}
 * @throws IOException if translating a state spec fails
 */
public static ParDoPayload toProto(final ParDo.MultiOutput<?, ?> parDo, SdkComponents components)
    throws IOException {
  final DoFn<?, ?> doFn = parDo.getFn();
  final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

  ParDoLike javaParDo =
      new ParDoLike() {
        @Override
        public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
          return toProto(parDo.getFn(), parDo.getMainOutputTag());
        }

        @Override
        public List<RunnerApi.Parameter> translateParameters() {
          List<RunnerApi.Parameter> translated = new ArrayList<>();
          for (Parameter extraParameter : signature.processElement().extraParameters()) {
            Optional<RunnerApi.Parameter> maybeProto = toProto(extraParameter);
            if (!maybeProto.isPresent()) {
              // Parameters without a proto representation are left out of the payload.
              continue;
            }
            translated.add(maybeProto.get());
          }
          return translated;
        }

        @Override
        public Map<String, SideInput> translateSideInputs(SdkComponents sideInputComponents) {
          Map<String, SideInput> byTagId = new HashMap<>();
          for (PCollectionView<?> view : parDo.getSideInputs()) {
            String tagId = view.getTagInternal().getId();
            byTagId.put(tagId, toProto(view));
          }
          return byTagId;
        }

        @Override
        public Map<String, RunnerApi.StateSpec> translateStateSpecs(
            SdkComponents stateComponents) throws IOException {
          Map<String, RunnerApi.StateSpec> byStateId = new HashMap<>();
          for (Map.Entry<String, StateDeclaration> declaration :
              signature.stateDeclarations().entrySet()) {
            StateSpec<?> javaSpec = getStateSpecOrCrash(declaration.getValue(), doFn);
            byStateId.put(declaration.getKey(), toProto(javaSpec, stateComponents));
          }
          return byStateId;
        }

        @Override
        public Map<String, RunnerApi.TimerSpec> translateTimerSpecs(
            SdkComponents newComponents) {
          Map<String, RunnerApi.TimerSpec> byTimerId = new HashMap<>();
          for (Map.Entry<String, TimerDeclaration> declaration :
              signature.timerDeclarations().entrySet()) {
            TimerSpec javaSpec = getTimerSpecOrCrash(declaration.getValue(), doFn);
            byTimerId.put(declaration.getKey(), toProto(javaSpec));
          }
          return byTimerId;
        }

        @Override
        public boolean isSplittable() {
          return signature.processElement().isSplittable();
        }
      };

  return payloadForParDoLike(javaParDo, components);
}
/**
 * Reads the {@link StateSpec} held by the field of a state declaration on the given {@link DoFn}.
 *
 * @param stateDeclaration the declaration whose field is read
 * @param target the {@link DoFn} instance to read the field from
 * @return the value of the declared field
 * @throws IllegalStateException if the field value is not a {@link StateSpec}
 * @throws RuntimeException if the field is not accessible; the underlying {@link
 *     IllegalAccessException} is attached as the cause
 */
private static StateSpec<?> getStateSpecOrCrash(
    StateDeclaration stateDeclaration, DoFn<?, ?> target) {
  try {
    Object fieldValue = stateDeclaration.field().get(target);
    checkState(
        fieldValue instanceof StateSpec,
        "Malformed %s class %s: state declaration field %s does not have type %s.",
        DoFn.class.getSimpleName(),
        target.getClass().getName(),
        stateDeclaration.field().getName(),
        StateSpec.class);
    // Reuse the value that was just validated instead of reading the field a second time.
    return (StateSpec<?>) fieldValue;
  } catch (IllegalAccessException exc) {
    // Keep the reflective failure as the cause so the stack trace shows why access was denied.
    throw new RuntimeException(
        String.format(
            "Malformed %s class %s: state declaration field %s is not accessible.",
            DoFn.class.getSimpleName(),
            target.getClass().getName(),
            stateDeclaration.field().getName()),
        exc);
  }
}
/**
 * Reads the {@link TimerSpec} held by the field of a timer declaration on the given {@link DoFn}.
 *
 * @param timerDeclaration the declaration whose field is read
 * @param target the {@link DoFn} instance to read the field from
 * @return the value of the declared field
 * @throws IllegalStateException if the field value is not a {@link TimerSpec}
 * @throws RuntimeException if the field is not accessible; the underlying {@link
 *     IllegalAccessException} is attached as the cause
 */
private static TimerSpec getTimerSpecOrCrash(
    TimerDeclaration timerDeclaration, DoFn<?, ?> target) {
  try {
    Object fieldValue = timerDeclaration.field().get(target);
    checkState(
        fieldValue instanceof TimerSpec,
        "Malformed %s class %s: timer declaration field %s does not have type %s.",
        DoFn.class.getSimpleName(),
        target.getClass().getName(),
        timerDeclaration.field().getName(),
        TimerSpec.class);
    // Reuse the value that was just validated instead of reading the field a second time.
    return (TimerSpec) fieldValue;
  } catch (IllegalAccessException exc) {
    // Keep the reflective failure as the cause so the stack trace shows why access was denied.
    throw new RuntimeException(
        String.format(
            "Malformed %s class %s: timer declaration field %s is not accessible.",
            DoFn.class.getSimpleName(),
            target.getClass().getName(),
            timerDeclaration.field().getName()),
        exc);
  }
}
/**
 * Deserializes the {@link DoFn} stored in the function spec of the given payload.
 *
 * @throws IllegalArgumentException if the function spec does not carry the custom Java DoFn URN
 */
public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException {
  DoFnAndMainOutput fnAndTag = doFnAndMainOutputTagFromProto(payload.getDoFn());
  return fnAndTag.getDoFn();
}
/**
 * Returns the {@link DoFn} of the applied transform.
 *
 * <p>A {@link ParDo.MultiOutput} is asked for its function directly; any other transform is
 * translated to a {@link ParDoPayload} first and the function is deserialized from there.
 *
 * @throws IOException if the transform cannot be translated or its payload cannot be parsed
 */
public static DoFn<?, ?> getDoFn(AppliedPTransform<?, ?, ?> application) throws IOException {
  PTransform<?, ?> transform = application.getTransform();
  if (!(transform instanceof ParDo.MultiOutput)) {
    return getDoFn(getParDoPayload(application));
  }
  ParDo.MultiOutput<?, ?> parDo = (ParDo.MultiOutput<?, ?>) transform;
  return parDo.getFn();
}
/**
 * Deserializes the main output tag stored alongside the {@link DoFn} in the given payload.
 *
 * @throws IllegalArgumentException if the function spec does not carry the custom Java DoFn URN
 */
public static TupleTag<?> getMainOutputTag(ParDoPayload payload)
    throws InvalidProtocolBufferException {
  DoFnAndMainOutput fnAndTag = doFnAndMainOutputTagFromProto(payload.getDoFn());
  return fnAndTag.getMainOutputTag();
}
/**
 * Returns the main output tag of the applied transform.
 *
 * <p>A {@link ParDo.MultiOutput} is asked for its tag directly; any other transform is translated
 * to a {@link ParDoPayload} first and the tag is deserialized from there.
 *
 * @throws IOException if the transform cannot be translated or its payload cannot be parsed
 */
public static TupleTag<?> getMainOutputTag(AppliedPTransform<?, ?, ?> application)
    throws IOException {
  PTransform<?, ?> transform = application.getTransform();
  if (!(transform instanceof ParDo.MultiOutput)) {
    return getMainOutputTag(getParDoPayload(application));
  }
  ParDo.MultiOutput<?, ?> parDo = (ParDo.MultiOutput<?, ?>) transform;
  return parDo.getMainOutputTag();
}
/**
 * Returns the additional (non-main) output tags of the applied transform.
 *
 * <p>A {@link ParDo.MultiOutput} is asked directly. Any other transform is translated to its
 * proto form; every output name of the proto other than the main output tag id then becomes an
 * additional output tag.
 *
 * @throws IOException if the transform cannot be translated or its payload cannot be parsed
 */
public static TupleTagList getAdditionalOutputTags(AppliedPTransform<?, ?, ?> application)
    throws IOException {
  PTransform<?, ?> transform = application.getTransform();
  if (transform instanceof ParDo.MultiOutput) {
    ParDo.MultiOutput<?, ?> parDo = (ParDo.MultiOutput<?, ?>) transform;
    return parDo.getAdditionalOutputTags();
  }

  RunnerApi.PTransform protoTransform =
      PTransformTranslation.toProto(application, SdkComponents.create());
  ParDoPayload payload = ParDoPayload.parseFrom(protoTransform.getSpec().getPayload());
  TupleTag<?> mainOutputTag = getMainOutputTag(payload);

  Set<String> allOutputNames = protoTransform.getOutputsMap().keySet();
  Set<String> mainOutputName = Collections.singleton(mainOutputTag.getId());
  Set<String> additionalOutputNames = Sets.difference(allOutputNames, mainOutputName);

  List<TupleTag<?>> additionalTags = new ArrayList<>();
  for (String outputName : additionalOutputNames) {
    additionalTags.add(new TupleTag<>(outputName));
  }
  return TupleTagList.of(additionalTags);
}
/**
 * Returns the side input views of the applied transform.
 *
 * <p>A {@link ParDo.MultiOutput} is asked directly. Any other transform is translated to its
 * proto form, and one view is rebuilt for every side input in the payload, wired to the
 * application input that carries the same tag.
 *
 * @throws IOException if the transform cannot be translated or a view cannot be rebuilt
 * @throws NullPointerException if the application has no input for a side input tag
 */
public static List<PCollectionView<?>> getSideInputs(AppliedPTransform<?, ?, ?> application)
    throws IOException {
  PTransform<?, ?> transform = application.getTransform();
  if (transform instanceof ParDo.MultiOutput) {
    ParDo.MultiOutput<?, ?> parDo = (ParDo.MultiOutput<?, ?>) transform;
    return parDo.getSideInputs();
  }

  SdkComponents sdkComponents = SdkComponents.create();
  RunnerApi.PTransform parDoProto = PTransformTranslation.toProto(application, sdkComponents);
  ParDoPayload payload = ParDoPayload.parseFrom(parDoProto.getSpec().getPayload());

  List<PCollectionView<?>> rebuiltViews = new ArrayList<>();
  RehydratedComponents rehydrated =
      RehydratedComponents.forComponents(sdkComponents.toComponents());
  for (Map.Entry<String, SideInput> entry : payload.getSideInputsMap().entrySet()) {
    String tagId = entry.getKey();
    RunnerApi.SideInput sideInputProto = entry.getValue();
    PCollection<?> inputPCollection =
        checkNotNull(
            (PCollection<?>) application.getInputs().get(new TupleTag<>(tagId)),
            "no input with tag %s",
            tagId);
    PCollectionView<?> view =
        viewFromProto(sideInputProto, tagId, inputPCollection, parDoProto, rehydrated);
    rebuiltViews.add(view);
  }
  return rebuiltViews;
}
/**
 * Looks up the main input {@link RunnerApi.PCollection} of a ParDo transform proto.
 *
 * <p>The main input is the single input of the transform whose name is not a side input name in
 * the {@link ParDoPayload}.
 *
 * @param ptransform the ParDo transform proto
 * @param components the components to resolve the PCollection id against
 * @throws IllegalArgumentException if {@code ptransform} does not carry the ParDo URN
 * @throws IOException if the payload cannot be parsed
 */
public static RunnerApi.PCollection getMainInput(
    RunnerApi.PTransform ptransform, Components components) throws IOException {
  FunctionSpec transformSpec = ptransform.getSpec();
  String urn = transformSpec.getUrn();
  checkArgument(urn.equals(PAR_DO_TRANSFORM_URN), "Unexpected payload type %s", urn);

  ParDoPayload payload = ParDoPayload.parseFrom(transformSpec.getPayload());
  Set<String> allInputNames = ptransform.getInputsMap().keySet();
  Set<String> sideInputNames = payload.getSideInputsMap().keySet();
  String mainInputId = Iterables.getOnlyElement(Sets.difference(allInputNames, sideInputNames));

  String mainPCollectionId = ptransform.getInputsOrThrow(mainInputId);
  return components.getPcollectionsOrThrow(mainPCollectionId);
}
/**
 * Translates a Java {@link StateSpec} into its proto form, registering every coder it uses with
 * {@code components}.
 *
 * @param stateSpec the spec to translate
 * @param components the components coders are registered with
 * @return the proto with the field matching the kind of {@code stateSpec} set
 */
@VisibleForTesting
static RunnerApi.StateSpec toProto(StateSpec<?> stateSpec, final SdkComponents components)
    throws IOException {
  final RunnerApi.StateSpec.Builder builder = RunnerApi.StateSpec.newBuilder();

  StateSpec.Cases<RunnerApi.StateSpec> translator =
      new StateSpec.Cases<RunnerApi.StateSpec>() {
        @Override
        public RunnerApi.StateSpec dispatchValue(Coder<?> valueCoder) {
          String coderId = registerCoderOrThrow(components, valueCoder);
          RunnerApi.ValueStateSpec.Builder valueSpec =
              RunnerApi.ValueStateSpec.newBuilder().setCoderId(coderId);
          return builder.setValueSpec(valueSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchBag(Coder<?> elementCoder) {
          String elementCoderId = registerCoderOrThrow(components, elementCoder);
          RunnerApi.BagStateSpec.Builder bagSpec =
              RunnerApi.BagStateSpec.newBuilder().setElementCoderId(elementCoderId);
          return builder.setBagSpec(bagSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchCombining(
            Combine.CombineFn<?, ?, ?> combineFn, Coder<?> accumCoder) {
          String accumulatorCoderId = registerCoderOrThrow(components, accumCoder);
          RunnerApi.CombiningStateSpec.Builder combiningSpec =
              RunnerApi.CombiningStateSpec.newBuilder()
                  .setAccumulatorCoderId(accumulatorCoderId)
                  .setCombineFn(CombineTranslation.toProto(combineFn));
          return builder.setCombiningSpec(combiningSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchMap(Coder<?> keyCoder, Coder<?> valueCoder) {
          // The key coder is registered before the value coder.
          String keyCoderId = registerCoderOrThrow(components, keyCoder);
          String valueCoderId = registerCoderOrThrow(components, valueCoder);
          RunnerApi.MapStateSpec.Builder mapSpec =
              RunnerApi.MapStateSpec.newBuilder()
                  .setKeyCoderId(keyCoderId)
                  .setValueCoderId(valueCoderId);
          return builder.setMapSpec(mapSpec).build();
        }

        @Override
        public RunnerApi.StateSpec dispatchSet(Coder<?> elementCoder) {
          String elementCoderId = registerCoderOrThrow(components, elementCoder);
          RunnerApi.SetStateSpec.Builder setSpec =
              RunnerApi.SetStateSpec.newBuilder().setElementCoderId(elementCoderId);
          return builder.setSetSpec(setSpec).build();
        }
      };

  return stateSpec.match(translator);
}
/**
 * Rebuilds a Java {@link StateSpec} from its proto form, resolving coder ids through {@code
 * components}.
 *
 * @param stateSpec the proto to convert
 * @param components the source of the coders referenced by id in {@code stateSpec}
 * @throws UnsupportedOperationException if a combining spec uses a combine function that is not
 *     a Java-serialized {@link Combine.CombineFn}
 * @throws IllegalArgumentException if no spec is set, or the spec kind is not recognized
 */
@VisibleForTesting
static StateSpec<?> fromProto(RunnerApi.StateSpec stateSpec, RehydratedComponents components)
    throws IOException {
  switch (stateSpec.getSpecCase()) {
    case VALUE_SPEC:
      {
        String valueCoderId = stateSpec.getValueSpec().getCoderId();
        return StateSpecs.value(components.getCoder(valueCoderId));
      }
    case BAG_SPEC:
      {
        String elementCoderId = stateSpec.getBagSpec().getElementCoderId();
        return StateSpecs.bag(components.getCoder(elementCoderId));
      }
    case COMBINING_SPEC:
      {
        FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn().getSpec();
        String combineFnUrn = combineFnSpec.getUrn();
        if (!combineFnUrn.equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN)) {
          String message =
              String.format(
                  "Cannot create %s from non-Java %s: %s",
                  StateSpec.class.getSimpleName(),
                  Combine.CombineFn.class.getSimpleName(),
                  combineFnSpec.getUrn());
          throw new UnsupportedOperationException(message);
        }
        byte[] serializedCombineFn = combineFnSpec.getPayload().toByteArray();
        Combine.CombineFn<?, ?, ?> combineFn =
            (Combine.CombineFn<?, ?, ?>)
                SerializableUtils.deserializeFromByteArray(
                    serializedCombineFn, Combine.CombineFn.class.getSimpleName());
        // Rawtype coder because it is required to be a valid accumulator coder
        // for the CombineFn, by construction
        Coder accumulatorCoder =
            components.getCoder(stateSpec.getCombiningSpec().getAccumulatorCoderId());
        return StateSpecs.combining(accumulatorCoder, combineFn);
      }
    case MAP_SPEC:
      {
        String keyCoderId = stateSpec.getMapSpec().getKeyCoderId();
        String valueCoderId = stateSpec.getMapSpec().getValueCoderId();
        return StateSpecs.map(
            components.getCoder(keyCoderId), components.getCoder(valueCoderId));
      }
    case SET_SPEC:
      {
        String elementCoderId = stateSpec.getSetSpec().getElementCoderId();
        return StateSpecs.set(components.getCoder(elementCoderId));
      }
    case SPEC_NOT_SET:
    default:
      throw new IllegalArgumentException(
          String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec));
  }
}
/**
 * Registers {@code coder} with {@code components} and returns the id it was registered under,
 * converting a checked {@link IOException} into an unchecked {@link RuntimeException}.
 */
private static String registerCoderOrThrow(SdkComponents components, Coder coder) {
  final String coderId;
  try {
    coderId = components.registerCoder(coder);
  } catch (IOException registrationFailure) {
    throw new RuntimeException("Failure to register coder", registrationFailure);
  }
  return coderId;
}
/** Translates a Java {@link TimerSpec} into a proto that carries only its time domain. */
private static RunnerApi.TimerSpec toProto(TimerSpec timer) {
  RunnerApi.TimeDomain.Enum protoDomain = toProto(timer.getTimeDomain());
  RunnerApi.TimerSpec.Builder timerSpec = RunnerApi.TimerSpec.newBuilder();
  timerSpec.setTimeDomain(protoDomain);
  return timerSpec.build();
}
/**
 * Maps a Java {@link TimeDomain} to the proto enum constant of the same name.
 *
 * @throws IllegalArgumentException if {@code timeDomain} has no proto counterpart here; the
 *     message names the offending value
 */
private static RunnerApi.TimeDomain.Enum toProto(TimeDomain timeDomain) {
  switch (timeDomain) {
    case EVENT_TIME:
      return RunnerApi.TimeDomain.Enum.EVENT_TIME;
    case PROCESSING_TIME:
      return RunnerApi.TimeDomain.Enum.PROCESSING_TIME;
    case SYNCHRONIZED_PROCESSING_TIME:
      return RunnerApi.TimeDomain.Enum.SYNCHRONIZED_PROCESSING_TIME;
    default:
      // Name the value so a newly added, untranslated domain can be identified from the error.
      throw new IllegalArgumentException("Unknown time domain: " + timeDomain);
  }
}
/**
 * A serializable pair of a {@link DoFn} and its main output {@link TupleTag}.
 *
 * <p>This is the object that is Java-serialized into the payload of a custom Java DoFn function
 * spec.
 */
@AutoValue
abstract static class DoFnAndMainOutput implements Serializable {
  /** The wrapped {@link DoFn}. */
  abstract DoFn<?, ?> getDoFn();

  /** The main output tag that accompanies the {@link DoFn}. */
  abstract TupleTag<?> getMainOutputTag();

  /** Creates a pair from the given function and main output tag. */
  public static DoFnAndMainOutput of(DoFn<?, ?> fn, TupleTag<?> tag) {
    return new AutoValue_ParDoTranslation_DoFnAndMainOutput(fn, tag);
  }
}
/**
 * Wraps a {@link DoFn} and its main output tag in an {@link SdkFunctionSpec} that carries the
 * custom Java DoFn URN and the Java-serialized {@link DoFnAndMainOutput} as payload.
 */
private static SdkFunctionSpec toProto(DoFn<?, ?> fn, TupleTag<?> tag) {
  byte[] serializedFnAndTag =
      SerializableUtils.serializeToByteArray(DoFnAndMainOutput.of(fn, tag));
  FunctionSpec doFnSpec =
      FunctionSpec.newBuilder()
          .setUrn(CUSTOM_JAVA_DO_FN_URN)
          .setPayload(ByteString.copyFrom(serializedFnAndTag))
          .build();
  return SdkFunctionSpec.newBuilder().setSpec(doFnSpec).build();
}
/**
 * Deserializes the {@link DoFnAndMainOutput} carried in the payload of a function spec.
 *
 * @throws IllegalArgumentException if the spec does not carry the custom Java DoFn URN
 */
private static DoFnAndMainOutput doFnAndMainOutputTagFromProto(SdkFunctionSpec fnSpec)
    throws InvalidProtocolBufferException {
  FunctionSpec doFnSpec = fnSpec.getSpec();
  checkArgument(
      doFnSpec.getUrn().equals(CUSTOM_JAVA_DO_FN_URN),
      "Expected %s to be %s with URN %s, but URN was %s",
      DoFn.class.getSimpleName(),
      FunctionSpec.class.getSimpleName(),
      CUSTOM_JAVA_DO_FN_URN,
      doFnSpec.getUrn());
  byte[] serializedFnAndTag = doFnSpec.getPayload().toByteArray();
  Object deserialized =
      SerializableUtils.deserializeFromByteArray(
          serializedFnAndTag, "Custom DoFn And Main Output tag");
  return (DoFnAndMainOutput) deserialized;
}
/**
 * Translates an extra {@code processElement} parameter into its proto form.
 *
 * @return a window or restriction tracker proto parameter for the corresponding Java parameter
 *     kinds, and {@link Optional#absent()} for every other kind
 */
private static Optional<RunnerApi.Parameter> toProto(Parameter parameter) {
  Cases.WithDefault<Optional<RunnerApi.Parameter>> translator =
      new Cases.WithDefault<Optional<RunnerApi.Parameter>>() {
        @Override
        public Optional<RunnerApi.Parameter> dispatch(WindowParameter p) {
          RunnerApi.Parameter windowParameter =
              RunnerApi.Parameter.newBuilder().setType(Type.Enum.WINDOW).build();
          return Optional.of(windowParameter);
        }

        @Override
        public Optional<RunnerApi.Parameter> dispatch(RestrictionTrackerParameter p) {
          RunnerApi.Parameter trackerParameter =
              RunnerApi.Parameter.newBuilder().setType(Type.Enum.RESTRICTION_TRACKER).build();
          return Optional.of(trackerParameter);
        }

        @Override
        protected Optional<RunnerApi.Parameter> dispatchDefault(Parameter p) {
          return Optional.absent();
        }
      };
  return parameter.match(translator);
}
/**
 * Translates a {@link PCollectionView} into a {@link SideInput} proto.
 *
 * <p>The access pattern is a spec carrying the URN of the view function's materialization; the
 * view function and the window mapping function are stored as serialized function specs.
 */
public static SideInput toProto(PCollectionView<?> view) {
  return SideInput.newBuilder()
      .setAccessPattern(
          FunctionSpec.newBuilder()
              .setUrn(view.getViewFn().getMaterialization().getUrn())
              .build())
      .setViewFn(toProto(view.getViewFn()))
      .setWindowMappingFn(toProto(view.getWindowMappingFn()))
      .build();
}
/**
 * Create a {@link PCollectionView} from a side input spec and an already-deserialized {@link
 * PCollection} that should be wired up.
 *
 * <p>The view function and window mapping function are deserialized from {@code sideInput}. The
 * view's windowing strategy is that of {@code pCollection} with defaults fixed, and its coder is
 * an iterable of windowed values built from the element coder and window coder of {@code
 * pCollection}.
 *
 * @param sideInput the side input spec to rebuild the view from
 * @param localName the tag id of the view; must not be null
 * @param pCollection the {@link PCollection} the view is wired to
 * @param parDoTransform not read by this method
 * @param components not read by this method
 * @throws IllegalArgumentException if {@code localName} is null, if a function spec carries an
 *     unexpected URN, or if the access pattern is not the iterable materialization
 */
public static PCollectionView<?> viewFromProto(
    SideInput sideInput,
    String localName,
    PCollection<?> pCollection,
    RunnerApi.PTransform parDoTransform,
    RehydratedComponents components)
    throws IOException {
  checkArgument(
      localName != null,
      "%s.viewFromProto: localName must not be null",
      ParDoTranslation.class.getSimpleName());
  TupleTag<?> viewTag = new TupleTag<>(localName);

  WindowMappingFn<?> mappingFn = windowMappingFnFromProto(sideInput.getWindowMappingFn());
  ViewFn<?, ?> deserializedViewFn = viewFnFromProto(sideInput.getViewFn());
  WindowingStrategy<?, ?> fixedStrategy = pCollection.getWindowingStrategy().fixDefaults();

  Coder<Iterable<WindowedValue<?>>> materializationCoder =
      (Coder)
          IterableCoder.of(
              FullWindowedValueCoder.of(
                  pCollection.getCoder(),
                  pCollection.getWindowingStrategy().getWindowFn().windowCoder()));

  // Only the iterable materialization is understood when rebuilding a view.
  String accessPatternUrn = sideInput.getAccessPattern().getUrn();
  checkArgument(
      accessPatternUrn.equals(Materializations.ITERABLE_MATERIALIZATION_URN),
      "Unknown View Materialization URN %s",
      sideInput.getAccessPattern().getUrn());

  PCollectionView<?> rebuiltView =
      new RunnerPCollectionView<>(
          pCollection,
          (TupleTag<Iterable<WindowedValue<?>>>) viewTag,
          (ViewFn<Iterable<WindowedValue<?>>, ?>) deserializedViewFn,
          mappingFn,
          fixedStrategy,
          materializationCoder);
  return rebuiltView;
}
/**
 * Wraps a {@link ViewFn} in an {@link SdkFunctionSpec} that carries the custom Java ViewFn URN
 * and the Java-serialized function as payload.
 */
private static SdkFunctionSpec toProto(ViewFn<?, ?> viewFn) {
  byte[] serializedViewFn = SerializableUtils.serializeToByteArray(viewFn);
  FunctionSpec viewFnSpec =
      FunctionSpec.newBuilder()
          .setUrn(CUSTOM_JAVA_VIEW_FN_URN)
          .setPayload(ByteString.copyFrom(serializedViewFn))
          .build();
  return SdkFunctionSpec.newBuilder().setSpec(viewFnSpec).build();
}
/**
 * Translates the applied transform to its proto form and parses the {@link ParDoPayload} out of
 * the resulting spec.
 *
 * <p>The translation runs against a fresh {@link SdkComponents} and an empty list of other
 * applied transforms.
 *
 * @param transform the applied transform to translate
 * @return the payload parsed from the translated transform's spec
 * @throws IOException if translation or payload parsing fails
 */
private static ParDoPayload getParDoPayload(AppliedPTransform<?, ?, ?> transform)
    throws IOException {
  RunnerApi.PTransform parDoPTransform =
      PTransformTranslation.toProto(
          transform, Collections.<AppliedPTransform<?, ?, ?>>emptyList(), SdkComponents.create());
  return ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
}
/**
 * Returns whether the {@link ParDoPayload} of the applied transform declares at least one state
 * spec or timer spec.
 *
 * @throws IOException if the transform cannot be translated or its payload cannot be parsed
 */
public static boolean usesStateOrTimers(AppliedPTransform<?, ?, ?> transform) throws IOException {
  ParDoPayload payload = getParDoPayload(transform);
  if (payload.getStateSpecsCount() > 0) {
    return true;
  }
  return payload.getTimerSpecsCount() > 0;
}
/**
 * Returns the splittable flag of the {@link ParDoPayload} of the applied transform.
 *
 * @throws IOException if the transform cannot be translated or its payload cannot be parsed
 */
public static boolean isSplittable(AppliedPTransform<?, ?, ?> transform) throws IOException {
  return getParDoPayload(transform).getSplittable();
}
/**
 * Deserializes the {@link ViewFn} carried in the payload of a function spec.
 *
 * @throws IllegalArgumentException if the spec does not carry the custom Java ViewFn URN
 */
private static ViewFn<?, ?> viewFnFromProto(SdkFunctionSpec viewFn)
    throws InvalidProtocolBufferException {
  FunctionSpec viewFnSpec = viewFn.getSpec();
  String urn = viewFnSpec.getUrn();
  checkArgument(
      urn.equals(CUSTOM_JAVA_VIEW_FN_URN),
      "Can't deserialize unknown %s type %s",
      ViewFn.class.getSimpleName(),
      urn);
  byte[] serializedViewFn = viewFnSpec.getPayload().toByteArray();
  Object deserialized =
      SerializableUtils.deserializeFromByteArray(serializedViewFn, "Custom ViewFn");
  return (ViewFn<?, ?>) deserialized;
}
/**
 * Wraps a {@link WindowMappingFn} in an {@link SdkFunctionSpec} that carries the custom Java
 * WindowMappingFn URN and the Java-serialized function as payload.
 */
private static SdkFunctionSpec toProto(WindowMappingFn<?> windowMappingFn) {
  byte[] serializedMappingFn = SerializableUtils.serializeToByteArray(windowMappingFn);
  FunctionSpec mappingFnSpec =
      FunctionSpec.newBuilder()
          .setUrn(CUSTOM_JAVA_WINDOW_MAPPING_FN_URN)
          .setPayload(ByteString.copyFrom(serializedMappingFn))
          .build();
  return SdkFunctionSpec.newBuilder().setSpec(mappingFnSpec).build();
}
/**
 * Deserializes the {@link WindowMappingFn} carried in the payload of a function spec.
 *
 * @param windowMappingFn the function spec to read
 * @return the deserialized {@link WindowMappingFn}
 * @throws IllegalArgumentException if the spec does not carry the custom Java WindowMappingFn URN
 */
private static WindowMappingFn<?> windowMappingFnFromProto(SdkFunctionSpec windowMappingFn)
    throws InvalidProtocolBufferException {
  FunctionSpec spec = windowMappingFn.getSpec();
  checkArgument(
      spec.getUrn().equals(CUSTOM_JAVA_WINDOW_MAPPING_FN_URN),
      "Can't deserialize unknown %s type %s",
      WindowMappingFn.class.getSimpleName(),
      spec.getUrn());
  // The second argument is a description handed to the deserializer; it was misspelled
  // "WinodwMappingFn".
  return (WindowMappingFn<?>)
      SerializableUtils.deserializeFromByteArray(
          spec.getPayload().toByteArray(), "Custom WindowMappingFn");
}
/**
 * A ParDo rehydrated from its proto form.
 *
 * <p>The {@link ParDoPayload} is parsed once at construction. The {@link ParDoLike} methods hand
 * back the parts of that payload unchanged, so {@link #migrate} rebuilds an equivalent payload
 * against a new set of components.
 */
static class RawParDo<InputT, OutputT>
    extends PTransformTranslation.RawPTransform<PCollection<InputT>, PCollection<OutputT>>
    implements ParDoLike {
  private final RunnerApi.PTransform protoTransform;
  // NOTE(review): transient, so this would be null on an instance restored by Java
  // serialization, and getAdditionalInputs() would then fail — confirm whether that can happen.
  private final transient RehydratedComponents rehydratedComponents;
  // Parsed from protoTransform and cached
  private final FunctionSpec spec;
  private final ParDoPayload payload;

  /**
   * @param protoTransform the ParDo transform proto
   * @param rehydratedComponents the components used to resolve side input PCollections
   * @throws IOException if the payload of {@code protoTransform} cannot be parsed
   */
  public RawParDo(RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents)
      throws IOException {
    this.rehydratedComponents = rehydratedComponents;
    this.protoTransform = protoTransform;
    this.spec = protoTransform.getSpec();
    this.payload = ParDoPayload.parseFrom(spec.getPayload());
  }

  /** Returns the spec of the original transform proto. */
  @Override
  public FunctionSpec getSpec() {
    return spec;
  }

  /** Rebuilds the ParDo spec, with its payload regenerated against {@code components}. */
  @Override
  public FunctionSpec migrate(SdkComponents components) throws IOException {
    return FunctionSpec.newBuilder()
        .setUrn(PAR_DO_TRANSFORM_URN)
        .setPayload(payloadForParDoLike(this, components).toByteString())
        .build();
  }

  /**
   * Returns the side input PCollections of this transform, keyed by side input tag.
   *
   * @throws IllegalStateException if a side input PCollection cannot be rehydrated; the
   *     underlying {@link IOException} is attached as the cause
   */
  @Override
  public Map<TupleTag<?>, PValue> getAdditionalInputs() {
    Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>();
    for (Map.Entry<String, SideInput> sideInputEntry : payload.getSideInputsMap().entrySet()) {
      try {
        additionalInputs.put(
            new TupleTag<>(sideInputEntry.getKey()),
            rehydratedComponents.getPCollection(
                protoTransform.getInputsOrThrow(sideInputEntry.getKey())));
      } catch (IOException exc) {
        // Attach the original failure so the reason for the failed lookup is not lost.
        throw new IllegalStateException(
            String.format(
                "Could not find input with name %s for %s transform",
                sideInputEntry.getKey(), ParDo.class.getSimpleName()),
            exc);
      }
    }
    return additionalInputs;
  }

  @Override
  public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
    // TODO: re-register the environment with the new components
    return payload.getDoFn();
  }

  @Override
  public List<RunnerApi.Parameter> translateParameters() {
    return MoreObjects.firstNonNull(
        payload.getParametersList(), Collections.<RunnerApi.Parameter>emptyList());
  }

  @Override
  public Map<String, SideInput> translateSideInputs(SdkComponents components) {
    // TODO: re-register the PCollections and UDF environments
    return MoreObjects.firstNonNull(
        payload.getSideInputsMap(), Collections.<String, SideInput>emptyMap());
  }

  @Override
  public Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components) {
    // TODO: re-register the coders
    return MoreObjects.firstNonNull(
        payload.getStateSpecsMap(), Collections.<String, RunnerApi.StateSpec>emptyMap());
  }

  @Override
  public Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents newComponents) {
    return MoreObjects.firstNonNull(
        payload.getTimerSpecsMap(), Collections.<String, RunnerApi.TimerSpec>emptyMap());
  }

  @Override
  public boolean isSplittable() {
    return payload.getSplittable();
  }
}
/**
 * These methods drive to-proto translation from Java and from rehydrated ParDos.
 *
 * <p>Each method supplies one part of a {@link ParDoPayload}; {@code payloadForParDoLike}
 * assembles the parts into the payload.
 */
private interface ParDoLike {
/** Returns the function spec stored as the payload's DoFn. */
SdkFunctionSpec translateDoFn(SdkComponents newComponents);
/** Returns the parameters added to the payload. */
List<RunnerApi.Parameter> translateParameters();
/** Returns the side inputs added to the payload, keyed by name. */
Map<String, RunnerApi.SideInput> translateSideInputs(SdkComponents components);
/**
 * Returns the state specs added to the payload, keyed by name.
 *
 * @throws IOException if a state spec cannot be translated
 */
Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components)
throws IOException;
/** Returns the timer specs added to the payload, keyed by name. */
Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents newComponents);
/** Returns the value stored as the payload's splittable flag. */
boolean isSplittable();
}
/**
 * Assembles a {@link ParDoPayload} from the parts supplied by a {@link ParDoLike}.
 *
 * <p>The parts are requested in a fixed order: DoFn, parameters, state specs, timer specs, side
 * inputs, and finally the splittable flag.
 *
 * @param parDo the source of the payload parts
 * @param components the components passed to each translation step that takes them
 * @throws IOException if translating the state specs fails
 */
public static ParDoPayload payloadForParDoLike(ParDoLike parDo, SdkComponents components)
    throws IOException {
  ParDoPayload.Builder payload = ParDoPayload.newBuilder();
  payload.setDoFn(parDo.translateDoFn(components));
  payload.addAllParameters(parDo.translateParameters());
  payload.putAllStateSpecs(parDo.translateStateSpecs(components));
  payload.putAllTimerSpecs(parDo.translateTimerSpecs(components));
  payload.putAllSideInputs(parDo.translateSideInputs(components));
  payload.setSplittable(parDo.isSplittable());
  return payload.build();
}
}