blob: 26b0dfa65dbd77e3104f818413316d54ba299cc1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.fn.harness.state;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.ReadableStates;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.StateBinder;
import org.apache.beam.sdk.state.StateContext;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
/** Provides access to side inputs and state via a {@link BeamFnStateClient}. */
public class FnApiStateAccessor implements SideInputReader, StateBinder {
private final PipelineOptions pipelineOptions;
private final Map<StateKey, Object> stateKeyObjectCache;
private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
private final BeamFnStateClient beamFnStateClient;
private final String ptransformId;
private final Supplier<String> processBundleInstructionId;
private final Collection<ThrowingRunnable> stateFinalizers;
private final Supplier<BoundedWindow> currentWindowSupplier;
private final Supplier<ByteString> encodedCurrentKeySupplier;
private final Supplier<ByteString> encodedCurrentWindowSupplier;
public FnApiStateAccessor(
PipelineOptions pipelineOptions,
String ptransformId,
Supplier<String> processBundleInstructionId,
Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
BeamFnStateClient beamFnStateClient,
Coder<?> keyCoder,
Coder<BoundedWindow> windowCoder,
Supplier<WindowedValue<?>> currentElementSupplier,
Supplier<BoundedWindow> currentWindowSupplier) {
this.pipelineOptions = pipelineOptions;
this.stateKeyObjectCache = Maps.newHashMap();
this.sideInputSpecMap = sideInputSpecMap;
this.beamFnStateClient = beamFnStateClient;
this.ptransformId = ptransformId;
this.processBundleInstructionId = processBundleInstructionId;
this.stateFinalizers = new ArrayList<>();
this.currentWindowSupplier = currentWindowSupplier;
this.encodedCurrentKeySupplier =
memoizeFunction(
currentElementSupplier,
element -> {
checkState(
element.getValue() instanceof KV,
"Accessing state in unkeyed context. Current element is not a KV: %s.",
element);
checkState(
keyCoder != null, "Accessing state in unkeyed context, no key coder available");
ByteString.Output encodedKeyOut = ByteString.newOutput();
try {
((Coder) keyCoder)
.encode(
((KV<?, ?>) element.getValue()).getKey(),
encodedKeyOut,
Coder.Context.NESTED);
} catch (IOException e) {
throw new IllegalStateException(e);
}
return encodedKeyOut.toByteString();
});
this.encodedCurrentWindowSupplier =
memoizeFunction(
currentWindowSupplier,
window -> {
ByteString.Output encodedWindowOut = ByteString.newOutput();
try {
windowCoder.encode(window, encodedWindowOut);
} catch (IOException e) {
throw new IllegalStateException(e);
}
return encodedWindowOut.toByteString();
});
}
private static <ArgT, ResultT> Supplier<ResultT> memoizeFunction(
Supplier<ArgT> arg, Function<ArgT, ResultT> f) {
return new Supplier<ResultT>() {
private ArgT memoizedArg;
private ResultT memoizedResult;
@Override
public ResultT get() {
ArgT currentArg = arg.get();
if (currentArg != memoizedArg) {
memoizedResult = f.apply(this.memoizedArg = currentArg);
}
return memoizedResult;
}
};
}
@Override
@Nullable
public <T> T get(PCollectionView<T> view, BoundedWindow window) {
TupleTag<?> tag = view.getTagInternal();
SideInputSpec sideInputSpec = sideInputSpecMap.get(tag);
checkArgument(sideInputSpec != null, "Attempting to access unknown side input %s.", view);
KvCoder<?, ?> kvCoder = (KvCoder) sideInputSpec.getCoder();
ByteString.Output encodedWindowOut = ByteString.newOutput();
try {
sideInputSpec
.getWindowCoder()
.encode(sideInputSpec.getWindowMappingFn().getSideInputWindow(window), encodedWindowOut);
} catch (IOException e) {
throw new IllegalStateException(e);
}
ByteString encodedWindow = encodedWindowOut.toByteString();
StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
cacheKeyBuilder
.getMultimapSideInputBuilder()
.setTransformId(ptransformId)
.setSideInputId(tag.getId())
.setWindow(encodedWindow);
return (T)
stateKeyObjectCache.computeIfAbsent(
cacheKeyBuilder.build(),
key ->
sideInputSpec
.getViewFn()
.apply(
new MultimapSideInput<>(
beamFnStateClient,
processBundleInstructionId.get(),
ptransformId,
tag.getId(),
encodedWindow,
kvCoder.getKeyCoder(),
kvCoder.getValueCoder())));
}
@Override
public <T> boolean contains(PCollectionView<T> view) {
return sideInputSpecMap.containsKey(view.getTagInternal());
}
@Override
public boolean isEmpty() {
return sideInputSpecMap.isEmpty();
}
@Override
public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
return (ValueState<T>)
stateKeyObjectCache.computeIfAbsent(
createBagUserStateKey(id),
new Function<StateKey, Object>() {
@Override
public Object apply(StateKey key) {
return new ValueState<T>() {
private final BagUserState<T> impl = createBagUserState(id, coder);
@Override
public void clear() {
impl.clear();
}
@Override
public void write(T input) {
impl.clear();
impl.append(input);
}
@Override
public T read() {
Iterator<T> value = impl.get().iterator();
if (value.hasNext()) {
return value.next();
} else {
return null;
}
}
@Override
public ValueState<T> readLater() {
// TODO: Support prefetching.
return this;
}
};
}
});
}
@Override
public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
return (BagState<T>)
stateKeyObjectCache.computeIfAbsent(
createBagUserStateKey(id),
new Function<StateKey, Object>() {
@Override
public Object apply(StateKey key) {
return new BagState<T>() {
private final BagUserState<T> impl = createBagUserState(id, elemCoder);
@Override
public void add(T value) {
impl.append(value);
}
@Override
public ReadableState<Boolean> isEmpty() {
return new ReadableState<Boolean>() {
@Nullable
@Override
public Boolean read() {
return !impl.get().iterator().hasNext();
}
@Override
public ReadableState<Boolean> readLater() {
return this;
}
};
}
@Override
public Iterable<T> read() {
return impl.get();
}
@Override
public BagState<T> readLater() {
// TODO: Support prefetching.
return this;
}
@Override
public void clear() {
impl.clear();
}
};
}
});
}
@Override
public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
}
@Override
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
String id,
StateSpec<MapState<KeyT, ValueT>> spec,
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder) {
throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
}
@Override
public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
String id,
StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
Coder<AccumT> accumCoder,
CombineFn<ElementT, AccumT, ResultT> combineFn) {
return (CombiningState<ElementT, AccumT, ResultT>)
stateKeyObjectCache.computeIfAbsent(
createBagUserStateKey(id),
new Function<StateKey, Object>() {
@Override
public Object apply(StateKey key) {
// TODO: Support squashing accumulators depending on whether we know of all
// remote accumulators and local accumulators or just local accumulators.
return new CombiningState<ElementT, AccumT, ResultT>() {
private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
@Override
public AccumT getAccum() {
Iterator<AccumT> iterator = impl.get().iterator();
if (iterator.hasNext()) {
return iterator.next();
}
return combineFn.createAccumulator();
}
@Override
public void addAccum(AccumT accum) {
Iterator<AccumT> iterator = impl.get().iterator();
// Only merge if there was a prior value
if (iterator.hasNext()) {
accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
// Since there was a prior value, we need to clear.
impl.clear();
}
impl.append(accum);
}
@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
return combineFn.mergeAccumulators(accumulators);
}
@Override
public CombiningState<ElementT, AccumT, ResultT> readLater() {
return this;
}
@Override
public ResultT read() {
Iterator<AccumT> iterator = impl.get().iterator();
if (iterator.hasNext()) {
return combineFn.extractOutput(iterator.next());
}
return combineFn.defaultValue();
}
@Override
public void add(ElementT value) {
AccumT newAccumulator = combineFn.addInput(getAccum(), value);
impl.clear();
impl.append(newAccumulator);
}
@Override
public ReadableState<Boolean> isEmpty() {
return ReadableStates.immediate(!impl.get().iterator().hasNext());
}
@Override
public void clear() {
impl.clear();
}
};
}
});
}
@Override
public <ElementT, AccumT, ResultT>
CombiningState<ElementT, AccumT, ResultT> bindCombiningWithContext(
String id,
StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
Coder<AccumT> accumCoder,
CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
return (CombiningState<ElementT, AccumT, ResultT>)
stateKeyObjectCache.computeIfAbsent(
createBagUserStateKey(id),
key ->
bindCombining(
id,
spec,
accumCoder,
CombineFnUtil.bindContext(
combineFn,
new StateContext<BoundedWindow>() {
@Override
public PipelineOptions getPipelineOptions() {
return pipelineOptions;
}
@Override
public <T> T sideInput(PCollectionView<T> view) {
return get(view, currentWindowSupplier.get());
}
@Override
public BoundedWindow window() {
return currentWindowSupplier.get();
}
})));
}
/**
* @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing and is
* waiting on resolution of BEAM-2535.
*/
@Override
@Deprecated
public WatermarkHoldState bindWatermark(
String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
}
private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> valueCoder) {
BagUserState<T> rval =
new BagUserState<>(
beamFnStateClient,
processBundleInstructionId.get(),
ptransformId,
stateId,
encodedCurrentWindowSupplier.get(),
encodedCurrentKeySupplier.get(),
valueCoder);
stateFinalizers.add(rval::asyncClose);
return rval;
}
private StateKey createBagUserStateKey(String stateId) {
StateKey.Builder builder = StateKey.newBuilder();
builder
.getBagUserStateBuilder()
.setWindow(encodedCurrentWindowSupplier.get())
.setKey(encodedCurrentKeySupplier.get())
.setTransformId(ptransformId)
.setUserStateId(stateId);
return builder.build();
}
public void finalizeState() {
// Persist all dirty state cells
try {
for (ThrowingRunnable runnable : stateFinalizers) {
runnable.run();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException(e);
} catch (Exception e) {
throw new IllegalStateException(e);
}
}
}