blob: a4016e28d4ce998b5e2f2af9ce8c4b70bc64a0c3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.flink.translation.functions;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.SideInputReference;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandlerFactory;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
/**
 * {@link SideInputHandlerFactory} for Flink streaming that creates the side input handlers used to
 * serve SDK side input state requests (see {@link StateRequestHandler}). Each handler delegates to
 * a runner-side {@link org.apache.beam.runners.core.SideInputHandler} to access the Flink broadcast
 * state that represents side inputs.
 */
public class FlinkStreamingSideInputHandlerFactory implements SideInputHandlerFactory {

  // Map from side input id to the PCollectionView used to look up its contents in runnerHandler.
  private final Map<SideInputId, PCollectionView<?>> sideInputToCollection;
  // Runner-side handler that returns the side input contents for a view and window.
  private final org.apache.beam.runners.core.SideInputHandler runnerHandler;

  /**
   * Creates a new side input handler factory for the given stage. Note that this requires a
   * traversal of the stage itself, so this should only be called once per stage rather than once
   * per bundle.
   *
   * @param stage the executable stage whose side inputs will be served
   * @param viewMapping mapping from side input id to its {@link PCollectionView}; must contain an
   *     entry for every side input of {@code stage}
   * @param runnerHandler runner-side handler used to read side input contents
   * @return a factory covering exactly the side inputs of {@code stage}
   * @throws NullPointerException if {@code viewMapping} has no entry for one of the stage's side
   *     inputs
   */
  public static FlinkStreamingSideInputHandlerFactory forStage(
      ExecutableStage stage,
      Map<SideInputId, PCollectionView<?>> viewMapping,
      org.apache.beam.runners.core.SideInputHandler runnerHandler) {
    ImmutableMap.Builder<SideInputId, PCollectionView<?>> sideInputBuilder = ImmutableMap.builder();
    for (SideInputReference sideInput : stage.getSideInputs()) {
      SideInputId sideInputId =
          SideInputId.newBuilder()
              .setTransformId(sideInput.transform().getId())
              .setLocalName(sideInput.localName())
              .build();
      sideInputBuilder.put(
          sideInputId,
          checkNotNull(
              viewMapping.get(sideInputId),
              "No side input for %s/%s",
              sideInputId.getTransformId(),
              sideInputId.getLocalName()));
    }
    return new FlinkStreamingSideInputHandlerFactory(sideInputBuilder.build(), runnerHandler);
  }

  private FlinkStreamingSideInputHandlerFactory(
      Map<SideInputId, PCollectionView<?>> sideInputToCollection,
      org.apache.beam.runners.core.SideInputHandler runnerHandler) {
    this.sideInputToCollection = sideInputToCollection;
    this.runnerHandler = runnerHandler;
  }

  /**
   * Returns a handler for the side input identified by {@code transformId}/{@code sideInputId}.
   *
   * <p>For the iterable access pattern the handler returns every element of the side input. For
   * the multimap access pattern it returns the values whose key, encoded with the key coder of
   * {@code elementCoder}, is byte-equal to the requested key.
   *
   * @throws IllegalArgumentException if the side input is unknown to this factory, if the access
   *     pattern URN is neither iterable nor multimap, or if a multimap side input's {@code
   *     elementCoder} is not a {@link KvCoder}
   */
  @Override
  public <T, V, W extends BoundedWindow> SideInputHandler<V, W> forSideInput(
      String transformId,
      String sideInputId,
      RunnerApi.FunctionSpec accessPattern,
      Coder<T> elementCoder,
      Coder<W> windowCoder) {
    // Wildcard instead of a raw type, so the helper calls below stay fully type-checked.
    PCollectionView<?> collectionNode =
        sideInputToCollection.get(
            SideInputId.newBuilder().setTransformId(transformId).setLocalName(sideInputId).build());
    checkArgument(collectionNode != null, "No side input for %s/%s", transformId, sideInputId);

    String urn = accessPattern.getUrn();
    if (PTransformTranslation.ITERABLE_SIDE_INPUT.equals(urn)) {
      @SuppressWarnings("unchecked") // T == V
      Coder<V> outputCoder = (Coder<V>) elementCoder;
      return forIterableSideInput(collectionNode, outputCoder);
    } else if (PTransformTranslation.MULTIMAP_SIDE_INPUT.equals(urn)) {
      // Fail with a descriptive error rather than an opaque ClassCastException below.
      checkArgument(
          elementCoder instanceof KvCoder,
          "Multimap side input %s/%s requires a KvCoder, got %s",
          transformId,
          sideInputId,
          elementCoder);
      @SuppressWarnings("unchecked") // T == KV<?, V>
      KvCoder<?, V> kvCoder = (KvCoder<?, V>) elementCoder;
      return forMultimapSideInput(collectionNode, kvCoder.getKeyCoder(), kvCoder.getValueCoder());
    } else {
      throw new IllegalArgumentException(
          String.format("Unknown side input access pattern: %s", accessPattern));
    }
  }

  /**
   * Creates a handler that ignores the requested key and returns all elements of {@code
   * collection} in the requested window.
   *
   * @throws NullPointerException from {@code get} if the runner has no contents for the window yet
   */
  private <T, W extends BoundedWindow> SideInputHandler<T, W> forIterableSideInput(
      PCollectionView<?> collection, Coder<T> elementCoder) {
    return new SideInputHandler<T, W>() {
      @Override
      public Iterable<T> get(byte[] key, W window) {
        @SuppressWarnings("unchecked") // elements of an iterable side input are of type T
        Iterable<T> values = (Iterable<T>) runnerHandler.getIterable(collection, window);
        return checkNotNull(values, "Element processed by SDK before side input is ready");
      }

      @Override
      public Coder<T> resultCoder() {
        return elementCoder;
      }
    };
  }

  /**
   * Creates a handler that scans all {@code KV} elements of {@code collection} in the requested
   * window and returns the values whose key, encoded with {@code keyCoder}, equals the requested
   * key bytes. This is a linear scan over the window's contents on every call.
   *
   * @throws NullPointerException from {@code get} if the runner has no contents for the window yet
   */
  private <K, V, W extends BoundedWindow> SideInputHandler<V, W> forMultimapSideInput(
      PCollectionView<?> collection, Coder<K> keyCoder, Coder<V> valueCoder) {
    return new SideInputHandler<V, W>() {
      @Override
      public Iterable<V> get(byte[] key, W window) {
        @SuppressWarnings("unchecked") // elements of a multimap side input are KV<K, V>
        Iterable<KV<K, V>> values =
            (Iterable<KV<K, V>>) runnerHandler.getIterable(collection, window);
        // Same readiness check as the iterable handler; previously this surfaced as a bare NPE.
        checkNotNull(values, "Element processed by SDK before side input is ready");

        ArrayList<V> result = new ArrayList<>();
        // One buffer for the whole scan; reset() clears it without discarding the allocation.
        ByteArrayOutputStream encodedKey = new ByteArrayOutputStream();
        for (KV<K, V> kv : values) {
          encodedKey.reset();
          try {
            keyCoder.encode(kv.getKey(), encodedKey);
          } catch (IOException ex) {
            throw new RuntimeException(ex);
          }
          if (Arrays.equals(key, encodedKey.toByteArray())) {
            result.add(kv.getValue());
          }
        }
        return result;
      }

      @Override
      public Coder<V> resultCoder() {
        return valueCoder;
      }
    };
  }
}