blob: 12577450a78fc8f24f28e223d2ec8a654a0438af [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.twister2.translators.functions;
import edu.iu.dsc.tws.api.tset.TSetContext;
import edu.iu.dsc.tws.api.tset.fn.ComputeCollectorFunc;
import edu.iu.dsc.tws.api.tset.fn.RecordCollector;
import java.io.IOException;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.StepContext;
import org.apache.beam.runners.core.construction.Environments;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.WindowingStrategyTranslation;
import org.apache.beam.runners.twister2.Twister2TranslationContext;
import org.apache.beam.runners.twister2.utils.NoOpStepContext;
import org.apache.beam.runners.twister2.utils.Twister2SideInputReader;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PortablePipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.util.DoFnWithExecutionInformation;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
/**
 * Twister2 compute function that runs a Beam {@link DoFn} over an iterator of windowed values and
 * emits the results as {@link RawUnionValue}s.
 *
 * <p>The Beam objects this function works with are kept in {@code transient} fields. The
 * constructor encodes them into the non-transient byte array, string and map fields, and {@code
 * initTransient()} rebuilds them on first use.
 */
@SuppressWarnings({
  "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class DoFnFunction<OutputT, InputT>
    implements ComputeCollectorFunc<RawUnionValue, Iterator<WindowedValue<InputT>>> {
  private static final Logger LOG = Logger.getLogger(DoFnFunction.class.getName());

  // Working state. Transient: rebuilt by initTransient() / prepare() rather than serialized.
  private transient DoFn<InputT, OutputT> doFn;
  private transient PipelineOptions pipelineOptions;
  private static final long serialVersionUID = -5701440128544343353L;
  private transient Coder<InputT> inputCoder;
  private transient Map<TupleTag<?>, Coder<?>> outputCoders;
  private transient WindowingStrategy<?, ?> windowingStrategy;
  private transient Map<TupleTag<?>, WindowingStrategy<?, ?>> sideInputs;
  private transient TupleTag<OutputT> mainOutput;
  private transient Twister2SideInputReader sideInputReader;
  private transient DoFnRunner<InputT, OutputT> doFnRunner;
  private transient DoFnOutputManager outputManager;
  private transient List<TupleTag<?>> sideOutputs;
  private StepContext stepcontext;
  private transient DoFnSchemaInformation doFnSchemaInformation;
  private transient Map<TupleTag<?>, Integer> outputMap;
  private transient Map<String, PCollectionView<?>> sideInputMapping;
  private transient DoFnInvoker doFnInvoker;

  // True once initTransient() has rebuilt the transient fields on this instance.
  private transient boolean isInitialized = false;

  // Encoded forms of the transient state above; these are the fields that get serialized.
  private transient RunnerApi.FunctionSpec doFnwithEx;
  private byte[] doFnwithExBytes;
  private byte[] coderBytes;
  private Map<String, byte[]> outputCodersBytes;
  private transient RunnerApi.MessageWithComponents windowStrategyProto;
  private byte[] windowBytes;
  private Map<String, byte[]> sideInputBytes;
  private String serializedOptions;
  private List<String> serializedSideOutputs;
  private Map<String, Integer> serializedOutputMap;

  public DoFnFunction() {
    // non arg constructor needed for kryo
    this.isInitialized = false;
  }

  /**
   * Creates the function and immediately encodes its state for serialization.
   *
   * @param context translation context supplying the pipeline options
   * @param doFn the user function to execute
   * @param inputCoder coder of the input elements
   * @param outputCoders coders of the outputs, keyed by output tag
   * @param sideOutputs tags of the additional outputs
   * @param windowingStrategy windowing strategy handed to the DoFn runner
   * @param sideInputs windowing strategy of each side input view
   * @param mainOutput tag of the main output
   * @param doFnSchemaInformation schema information handed to the DoFn runner
   * @param outputMap integer union tag assigned to each output tag
   * @param sideInputMapping side input views keyed by string id
   * @throws IllegalStateException if the windowing strategies cannot be encoded
   */
  public DoFnFunction(
      Twister2TranslationContext context,
      DoFn<InputT, OutputT> doFn,
      Coder<InputT> inputCoder,
      Map<TupleTag<?>, Coder<?>> outputCoders,
      List<TupleTag<?>> sideOutputs,
      WindowingStrategy<?, ?> windowingStrategy,
      Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
      TupleTag<OutputT> mainOutput,
      DoFnSchemaInformation doFnSchemaInformation,
      Map<TupleTag<?>, Integer> outputMap,
      Map<String, PCollectionView<?>> sideInputMapping) {
    this.doFn = doFn;
    this.pipelineOptions = context.getOptions();
    this.inputCoder = inputCoder;
    this.outputCoders = outputCoders;
    this.windowingStrategy = windowingStrategy;
    this.sideInputs = convertToTuples(sideInputs);
    this.mainOutput = mainOutput;
    this.doFnSchemaInformation = doFnSchemaInformation;
    this.sideOutputs = sideOutputs;
    this.stepcontext = new NoOpStepContext();
    this.outputMap = outputMap;
    this.sideInputMapping = sideInputMapping;
    outputManager = new DoFnOutputManager(this.outputMap);
    // Also fills in serializedOptions, so it is not computed separately here.
    prepareSerialization();
  }

  /** Re-keys the side input windowing strategies by the internal tag of each view. */
  private Map<TupleTag<?>, WindowingStrategy<?, ?>> convertToTuples(
      Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs) {
    Map<TupleTag<?>, WindowingStrategy<?, ?>> temp = new HashMap<>();
    for (Map.Entry<PCollectionView<?>, WindowingStrategy<?, ?>> entry : sideInputs.entrySet()) {
      temp.put(entry.getKey().getTagInternal(), entry.getValue());
    }
    return temp;
  }

  /**
   * get the tag id's of all the keys.
   *
   * @return A Set of String key values.
   */
  public Set<String> getSideInputKeys() {
    initTransient();
    Set<String> keys = new HashSet<>();
    for (TupleTag<?> view : sideInputs.keySet()) {
      keys.add(view.getId());
    }
    return keys;
  }

  /**
   * Rebuilds the transient state, creates the side input reader, runs the DoFn setup through
   * {@link DoFnInvokers#tryInvokeSetupFor} and creates the runner used by {@link #compute}.
   *
   * @param context Twister2 context handed to the side input reader
   */
  @Override
  public void prepare(TSetContext context) {
    initTransient();
    sideInputReader = new Twister2SideInputReader(sideInputs, context);
    outputManager.setup(mainOutput, sideOutputs);
    doFnInvoker = DoFnInvokers.tryInvokeSetupFor(doFn, pipelineOptions);
    doFnRunner =
        DoFnRunners.simpleRunner(
            pipelineOptions,
            doFn,
            sideInputReader,
            outputManager,
            mainOutput,
            sideOutputs,
            stepcontext,
            inputCoder,
            outputCoders,
            windowingStrategy,
            doFnSchemaInformation,
            sideInputMapping);
  }

  /**
   * Processes all of {@code input} as a single bundle and forwards the buffered results to {@code
   * output} once the bundle has finished.
   *
   * <p>If a {@link RuntimeException} is thrown, the DoFn is torn down and the exception is
   * rethrown. A failure of the teardown itself is attached to that exception as suppressed.
   *
   * @param input elements to process
   * @param output collector receiving the produced union values
   */
  @Override
  public void compute(
      Iterator<WindowedValue<InputT>> input, RecordCollector<RawUnionValue> output) {
    try {
      outputManager.clear();
      doFnRunner.startBundle();
      while (input.hasNext()) {
        doFnRunner.processElement(input.next());
      }
      doFnRunner.finishBundle();
      Iterator<RawUnionValue> outputs = outputManager.getOutputs();
      while (outputs.hasNext()) {
        output.collect(outputs.next());
      }
    } catch (final RuntimeException re) {
      try {
        teardownDoFn();
      } catch (RuntimeException teardownFailure) {
        // Keep the bundle failure as the primary error instead of masking it.
        re.addSuppressed(teardownFailure);
        LOG.warning(
            "DoFn teardown failed after a bundle failure: " + teardownFailure.getMessage());
      }
      throw re;
    }
  }

  /** Tears down the DoFn if it was set up and has not been torn down already. */
  @Override
  public void close() {
    teardownDoFn();
  }

  /**
   * Invokes teardown through the invoker created in {@link #prepare}, at most once. Shared by
   * {@link #compute} and {@link #close} so that a failed bundle followed by a close does not tear
   * the DoFn down twice.
   */
  private void teardownDoFn() {
    DoFnInvoker invoker = doFnInvoker;
    // Drop the reference before invoking so a throwing teardown is not retried.
    doFnInvoker = null;
    Optional.ofNullable(invoker).ifPresent(DoFnInvoker::invokeTeardown);
  }

  /**
   * Output manager that buffers the values emitted by the DoFn runner as {@link RawUnionValue}s.
   * Values for tags that were not registered through {@link #setup} are dropped.
   */
  private static class DoFnOutputManager implements DoFnRunners.OutputManager, Serializable {
    // todo need to figure out how this class types are handled
    private static final long serialVersionUID = 4967375172737408160L;
    private transient List<RawUnionValue> outputs;
    private transient Set<TupleTag<?>> outputTags;
    private Map<TupleTag<?>, Integer> outputMap;

    private DoFnOutputManager() {}

    DoFnOutputManager(Map<TupleTag<?>, Integer> outputMap) {
      this.outputMap = outputMap;
    }

    @Override
    public <T> void output(TupleTag<T> outputTag, WindowedValue<T> output) {
      if (outputTags.contains(outputTag)) {
        outputs.add(new RawUnionValue(outputMap.get(outputTag), output));
      }
    }

    /** Creates the buffer and registers the tags whose values are kept. */
    void setup(TupleTag<?> mainOutput, List<TupleTag<?>> sideOutputs) {
      outputs = new ArrayList<>();
      outputTags = new HashSet<>();
      outputTags.add(mainOutput);
      outputTags.addAll(sideOutputs);
    }

    void clear() {
      outputs.clear();
    }

    Iterator<RawUnionValue> getOutputs() {
      return outputs.iterator();
    }
  }

  protected Object readResolve() throws ObjectStreamException {
    return this;
  }

  /**
   * Encodes the transient state into the serializable fields: the pipeline options as a string,
   * the DoFn, coders and windowing strategies as byte arrays, and the output tags by their ids.
   * {@code initTransient()} performs the reverse conversion.
   *
   * @throws IllegalStateException if encoding fails with an {@link IOException}; continuing would
   *     leave the function only partially encoded
   */
  private void prepareSerialization() {
    SdkComponents components = SdkComponents.create();
    components.registerEnvironment(
        Environments.createOrGetDefaultEnvironment(
            pipelineOptions.as(PortablePipelineOptions.class)));
    this.serializedOptions = new SerializablePipelineOptions(pipelineOptions).toString();
    doFnwithEx =
        ParDoTranslation.translateDoFn(
            this.doFn, mainOutput, sideInputMapping, doFnSchemaInformation, components);
    doFnwithExBytes = doFnwithEx.getPayload().toByteArray();
    outputCodersBytes = new HashMap<>();
    try {
      coderBytes = SerializableUtils.serializeToByteArray(inputCoder);
      windowStrategyProto =
          WindowingStrategyTranslation.toMessageProto(windowingStrategy, components);
      windowBytes = windowStrategyProto.toByteArray();
      for (Map.Entry<TupleTag<?>, Coder<?>> entry : outputCoders.entrySet()) {
        outputCodersBytes.put(
            entry.getKey().getId(), SerializableUtils.serializeToByteArray(entry.getValue()));
      }
      sideInputBytes = new HashMap<>();
      for (Map.Entry<TupleTag<?>, WindowingStrategy<?, ?>> entry : sideInputs.entrySet()) {
        RunnerApi.MessageWithComponents sideInputProto =
            WindowingStrategyTranslation.toMessageProto(entry.getValue(), components);
        sideInputBytes.put(entry.getKey().getId(), sideInputProto.toByteArray());
      }
      serializedSideOutputs = new ArrayList<>();
      for (TupleTag<?> sideOutput : sideOutputs) {
        serializedSideOutputs.add(sideOutput.getId());
      }
      serializedOutputMap = new HashMap<>();
      for (Map.Entry<TupleTag<?>, Integer> entry : outputMap.entrySet()) {
        serializedOutputMap.put(entry.getKey().getId(), entry.getValue());
      }
    } catch (IOException e) {
      throw new IllegalStateException("Failed to prepare the DoFnFunction for serialization", e);
    }
  }

  /**
   * Rebuilds the transient fields from their encoded forms. Does nothing if this instance has
   * already been initialized.
   *
   * @throws IllegalStateException if a windowing strategy cannot be parsed; the function is then
   *     not marked as initialized
   */
  private void initTransient() {
    if (isInitialized) {
      return;
    }
    try {
      SdkComponents components = SdkComponents.create();
      // The same (empty) component set is used for every strategy, so build it once.
      RehydratedComponents rehydratedComponents =
          RehydratedComponents.forComponents(components.toComponents());
      pipelineOptions = new SerializablePipelineOptions(serializedOptions).get();
      DoFnWithExecutionInformation doFnWithExecutionInformation =
          (DoFnWithExecutionInformation)
              SerializableUtils.deserializeFromByteArray(doFnwithExBytes, "Custom Coder Bytes");
      this.doFn = (DoFn<InputT, OutputT>) doFnWithExecutionInformation.getDoFn();
      this.mainOutput = (TupleTag<OutputT>) doFnWithExecutionInformation.getMainOutputTag();
      this.sideInputMapping = doFnWithExecutionInformation.getSideInputMapping();
      this.doFnSchemaInformation = doFnWithExecutionInformation.getSchemaInformation();
      inputCoder =
          (Coder<InputT>)
              SerializableUtils.deserializeFromByteArray(coderBytes, "Custom Coder Bytes");
      windowStrategyProto = RunnerApi.MessageWithComponents.parseFrom(windowBytes);
      windowingStrategy =
          (WindowingStrategy<?, ?>)
              WindowingStrategyTranslation.fromProto(
                  windowStrategyProto.getWindowingStrategy(), rehydratedComponents);
      sideInputs = new HashMap<>();
      for (Map.Entry<String, byte[]> entry : sideInputBytes.entrySet()) {
        RunnerApi.MessageWithComponents sideInputProto =
            RunnerApi.MessageWithComponents.parseFrom(entry.getValue());
        sideInputs.put(
            new TupleTag<>(entry.getKey()),
            WindowingStrategyTranslation.fromProto(
                sideInputProto.getWindowingStrategy(), rehydratedComponents));
      }
    } catch (InvalidProtocolBufferException e) {
      throw new IllegalStateException(
          "Failed to restore the windowing strategies of the DoFnFunction", e);
    }
    outputCoders = new HashMap<>();
    for (Map.Entry<String, byte[]> entry : outputCodersBytes.entrySet()) {
      outputCoders.put(
          new TupleTag<>(entry.getKey()),
          (Coder<?>)
              SerializableUtils.deserializeFromByteArray(entry.getValue(), "Custom Coder Bytes"));
    }
    sideOutputs = new ArrayList<>();
    for (String sideOutput : serializedSideOutputs) {
      sideOutputs.add(new TupleTag<>(sideOutput));
    }
    outputMap = new HashMap<>();
    for (Map.Entry<String, Integer> entry : serializedOutputMap.entrySet()) {
      outputMap.put(new TupleTag<>(entry.getKey()), entry.getValue());
    }
    outputManager = new DoFnOutputManager(this.outputMap);
    this.isInitialized = true;
  }
}