blob: 2d22650efb44aa41eba00f94e1741ed39f388076 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.StateNamespaces.WindowNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.core.TimerInternals.TimerData;
import org.apache.beam.runners.core.TimerInternals.TimerDataCoder;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.MapCoder;
import org.apache.beam.sdk.coders.SetCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.Parser;
import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
/** A class that handles streaming side inputs in a {@link DoFnRunner}. */
public class StreamingSideInputFetcher<InputT, W extends BoundedWindow> {
private StreamingModeExecutionContext.StreamingModeStepContext stepContext;
private Map<String, PCollectionView<?>> sideInputViews;
private final StateTag<BagState<WindowedValue<InputT>>> elementsAddr;
private final StateTag<BagState<TimerData>> timersAddr;
private final StateTag<WatermarkHoldState> watermarkHoldingAddr;
private final StateTag<ValueState<Map<W, Set<Windmill.GlobalDataRequest>>>> blockedMapAddr;
private Map<W, Set<Windmill.GlobalDataRequest>> blockedMap = null; // lazily initialized
private final Coder<W> mainWindowCoder;
public StreamingSideInputFetcher(
Iterable<PCollectionView<?>> views,
Coder<InputT> inputCoder,
WindowingStrategy<?, W> windowingStrategy,
StreamingModeExecutionContext.StreamingModeStepContext stepContext) {
this.stepContext = stepContext;
this.mainWindowCoder = windowingStrategy.getWindowFn().windowCoder();
this.sideInputViews = new HashMap<>();
for (PCollectionView<?> view : views) {
sideInputViews.put(view.getTagInternal().getId(), view);
}
this.blockedMapAddr = blockedMapAddr(mainWindowCoder);
this.elementsAddr =
StateTags.makeSystemTagInternal(
StateTags.bag("elem", WindowedValue.getFullCoder(inputCoder, mainWindowCoder)));
this.timersAddr =
StateTags.makeSystemTagInternal(StateTags.bag("timer", TimerDataCoder.of(mainWindowCoder)));
StateTag<WatermarkHoldState> watermarkTag =
StateTags.watermarkStateInternal(
"holdForSideinput", windowingStrategy.getTimestampCombiner());
this.watermarkHoldingAddr = StateTags.makeSystemTagInternal(watermarkTag);
}
@VisibleForTesting
static <W extends BoundedWindow>
StateTag<ValueState<Map<W, Set<GlobalDataRequest>>>> blockedMapAddr(
Coder<W> mainWindowCoder) {
return StateTags.value(
"blockedMap", MapCoder.of(mainWindowCoder, SetCoder.of(GlobalDataRequestCoder.of())));
}
/** Computes the set of main input windows for which all side inputs are ready and cached. */
public Set<W> getReadyWindows() {
Set<W> readyWindows = new HashSet<>();
for (Windmill.GlobalDataId id : stepContext.getSideInputNotifications()) {
if (sideInputViews.get(id.getTag()) == null) {
// Side input is for a different DoFn; ignore it.
continue;
}
for (Map.Entry<W, Set<Windmill.GlobalDataRequest>> entry : blockedMap().entrySet()) {
Set<Windmill.GlobalDataRequest> windowBlockedSet = entry.getValue();
Set<Windmill.GlobalDataRequest> found = new HashSet<>();
for (Windmill.GlobalDataRequest request : windowBlockedSet) {
if (id.equals(request.getDataId())) {
found.add(request);
}
}
windowBlockedSet.removeAll(found);
if (windowBlockedSet.isEmpty()) {
// Notifications were received for all side inputs for this window.
// Issue fetches for all the needed side inputs to make sure they are all present
// in the local cache. If not, note the side inputs as still being blocked.
W window = entry.getKey();
boolean allSideInputsCached = true;
for (PCollectionView<?> view : sideInputViews.values()) {
if (!stepContext.issueSideInputFetch(
view, window, StateFetcher.SideInputState.KNOWN_READY)) {
Windmill.GlobalDataRequest request = buildGlobalDataRequest(view, window);
stepContext.addBlockingSideInput(request);
windowBlockedSet.add(request);
allSideInputsCached = false;
}
}
if (allSideInputsCached) {
readyWindows.add(window);
}
}
}
}
return readyWindows;
}
public Iterable<BagState<WindowedValue<InputT>>> prefetchElements(Iterable<W> readyWindows) {
List<BagState<WindowedValue<InputT>>> elements = Lists.newArrayList();
for (W window : readyWindows) {
elements.add(elementBag(window).readLater());
}
return elements;
}
public void releaseBlockedWindows(Iterable<W> windows) {
for (W window : windows) {
WatermarkHoldState watermarkHold = watermarkHold(window);
watermarkHold.clear();
blockedMap().remove(window);
}
}
public Iterable<BagState<TimerData>> prefetchTimers(Iterable<W> readyWindows) {
List<BagState<TimerData>> timers = Lists.newArrayList();
for (W window : readyWindows) {
timers.add(timerBag(window).readLater());
}
return timers;
}
/** Compute the set of side inputs that are not yet ready for the given main input window. */
public boolean storeIfBlocked(WindowedValue<InputT> elem) {
@SuppressWarnings("unchecked")
W window = (W) Iterables.getOnlyElement(elem.getWindows());
Set<Windmill.GlobalDataRequest> blocked = blockedMap().get(window);
if (blocked == null) {
for (PCollectionView<?> view : sideInputViews.values()) {
if (!stepContext.issueSideInputFetch(view, window, StateFetcher.SideInputState.UNKNOWN)) {
if (blocked == null) {
blocked = new HashSet<>();
blockedMap().put(window, blocked);
}
blocked.add(buildGlobalDataRequest(view, window));
}
}
}
if (blocked != null) {
elementBag(window).add(elem);
watermarkHold(window).add(elem.getTimestamp());
stepContext.addBlockingSideInputs(blocked);
return true;
} else {
return false;
}
}
public boolean storeIfBlocked(TimerData timer) {
if (!(timer.getNamespace() instanceof WindowNamespace)) {
throw new IllegalArgumentException(
"Expected WindowNamespace, but was " + timer.getNamespace());
}
@SuppressWarnings("unchecked")
WindowNamespace<W> windowNamespace = (WindowNamespace<W>) timer.getNamespace();
W window = windowNamespace.getWindow();
boolean blocked = false;
for (PCollectionView<?> view : sideInputViews.values()) {
if (!stepContext.issueSideInputFetch(view, window, StateFetcher.SideInputState.UNKNOWN)) {
blocked = true;
}
}
if (blocked) {
timerBag(window).add(timer);
}
return blocked;
}
public void persist() {
if (blockedMap == null) {
return;
}
ValueState<Map<W, Set<Windmill.GlobalDataRequest>>> mapState =
stepContext.stateInternals().state(StateNamespaces.global(), blockedMapAddr);
if (blockedMap.isEmpty()) {
// Avoid storing the empty map so we don't leave unnecessary state behind from processing
// the key.
mapState.clear();
} else {
mapState.write(blockedMap);
}
blockedMap = null;
}
private Map<W, Set<Windmill.GlobalDataRequest>> blockedMap() {
if (blockedMap == null) {
blockedMap =
stepContext.stateInternals().state(StateNamespaces.global(), blockedMapAddr).read();
if (blockedMap == null) {
blockedMap = new HashMap<>();
}
}
return blockedMap;
}
@VisibleForTesting
Set<W> getBlockedWindows() {
return blockedMap().keySet();
}
@VisibleForTesting
BagState<WindowedValue<InputT>> elementBag(W window) {
return stepContext
.stateInternals()
.state(StateNamespaces.window(mainWindowCoder, window), elementsAddr);
}
@VisibleForTesting
WatermarkHoldState watermarkHold(W window) {
return stepContext
.stateInternals()
.state(StateNamespaces.window(mainWindowCoder, window), watermarkHoldingAddr);
}
@VisibleForTesting
BagState<TimerData> timerBag(W window) {
return stepContext
.stateInternals()
.state(StateNamespaces.window(mainWindowCoder, window), timersAddr);
}
private <SideWindowT extends BoundedWindow> Windmill.GlobalDataRequest buildGlobalDataRequest(
PCollectionView<?> view, BoundedWindow mainWindow) {
@SuppressWarnings("unchecked")
WindowingStrategy<?, SideWindowT> sideWindowStrategy =
(WindowingStrategy<?, SideWindowT>) view.getWindowingStrategyInternal();
WindowFn<?, SideWindowT> sideWindowFn = sideWindowStrategy.getWindowFn();
Coder<SideWindowT> sideInputWindowCoder = sideWindowFn.windowCoder();
SideWindowT sideInputWindow =
(SideWindowT) view.getWindowMappingFn().getSideInputWindow(mainWindow);
ByteString.Output windowStream = ByteString.newOutput();
try {
sideInputWindowCoder.encode(sideInputWindow, windowStream, Coder.Context.OUTER);
} catch (IOException e) {
throw new RuntimeException(e);
}
return Windmill.GlobalDataRequest.newBuilder()
.setDataId(
Windmill.GlobalDataId.newBuilder()
.setTag(view.getTagInternal().getId())
.setVersion(windowStream.toByteString())
.build())
.setExistenceWatermarkDeadline(
WindmillTimeUtils.harnessToWindmillTimestamp(
sideWindowStrategy.getTrigger().getWatermarkThatGuaranteesFiring(sideInputWindow)))
.build();
}
private static class GlobalDataRequestCoder extends AtomicCoder<GlobalDataRequest> {
private final Class<Windmill.GlobalDataRequest> protoMessageClass =
Windmill.GlobalDataRequest.class;
private transient Parser<Windmill.GlobalDataRequest> memoizedParser;
public static GlobalDataRequestCoder of() {
return new GlobalDataRequestCoder();
}
@Override
public Windmill.GlobalDataRequest decode(InputStream inStream) throws IOException {
return decode(inStream, Context.NESTED);
}
@Override
public Windmill.GlobalDataRequest decode(InputStream inStream, Context context)
throws IOException {
if (context.isWholeStream) {
return getParser().parseFrom(inStream);
} else {
return getParser().parseDelimitedFrom(inStream);
}
}
@Override
public void encode(Windmill.GlobalDataRequest value, OutputStream outStream)
throws IOException {
encode(value, outStream, Context.NESTED);
}
@Override
public void encode(Windmill.GlobalDataRequest value, OutputStream outStream, Context context)
throws IOException {
if (value == null) {
throw new CoderException("cannot encode a null " + protoMessageClass.getSimpleName());
}
if (context.isWholeStream) {
value.writeTo(outStream);
} else {
value.writeDelimitedTo(outStream);
}
}
private Parser<Windmill.GlobalDataRequest> getParser() {
if (memoizedParser == null) {
memoizedParser = Windmill.GlobalDataRequest.getDefaultInstance().getParserForType();
}
return memoizedParser;
}
}
}