| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.fn.harness.state; |
| |
| import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; |
| import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.function.Function; |
| import java.util.function.Supplier; |
| import org.apache.beam.fn.harness.Cache; |
| import org.apache.beam.fn.harness.Caches; |
| import org.apache.beam.model.fnexecution.v1.BeamFnApi; |
| import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest.CacheToken; |
| import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; |
| import org.apache.beam.runners.core.SideInputReader; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.KvCoder; |
| import org.apache.beam.sdk.coders.VoidCoder; |
| import org.apache.beam.sdk.function.ThrowingRunnable; |
| import org.apache.beam.sdk.options.PipelineOptions; |
| import org.apache.beam.sdk.state.BagState; |
| import org.apache.beam.sdk.state.CombiningState; |
| import org.apache.beam.sdk.state.MapState; |
| import org.apache.beam.sdk.state.OrderedListState; |
| import org.apache.beam.sdk.state.ReadableState; |
| import org.apache.beam.sdk.state.ReadableStates; |
| import org.apache.beam.sdk.state.SetState; |
| import org.apache.beam.sdk.state.StateBinder; |
| import org.apache.beam.sdk.state.StateContext; |
| import org.apache.beam.sdk.state.StateSpec; |
| import org.apache.beam.sdk.state.ValueState; |
| import org.apache.beam.sdk.state.WatermarkHoldState; |
| import org.apache.beam.sdk.transforms.Combine.CombineFn; |
| import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; |
| import org.apache.beam.sdk.transforms.Materializations; |
| import org.apache.beam.sdk.transforms.windowing.BoundedWindow; |
| import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; |
| import org.apache.beam.sdk.util.ByteStringOutputStream; |
| import org.apache.beam.sdk.util.CombineFnUtil; |
| import org.apache.beam.sdk.values.PCollectionView; |
| import org.apache.beam.sdk.values.TupleTag; |
| import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; |
| import org.checkerframework.checker.nullness.qual.Nullable; |
| |
| /** Provides access to side inputs and state via a {@link BeamFnStateClient}. */ |
| @SuppressWarnings({ |
| "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) |
| "nullness" // TODO(https://github.com/apache/beam/issues/20497) |
| }) |
| public class FnApiStateAccessor<K> implements SideInputReader, StateBinder { |
| private final PipelineOptions pipelineOptions; |
| private final Map<StateKey, Object> stateKeyObjectCache; |
| private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap; |
| private final BeamFnStateClient beamFnStateClient; |
| private final String ptransformId; |
| private final Supplier<String> processBundleInstructionId; |
| private final Supplier<List<BeamFnApi.ProcessBundleRequest.CacheToken>> cacheTokens; |
| private final Supplier<Cache<?, ?>> bundleCache; |
| private final Cache<?, ?> processWideCache; |
| private final Collection<ThrowingRunnable> stateFinalizers; |
| |
| private final Supplier<BoundedWindow> currentWindowSupplier; |
| |
| private final Supplier<ByteString> encodedCurrentKeySupplier; |
| private final Supplier<ByteString> encodedCurrentWindowSupplier; |
| |
| public FnApiStateAccessor( |
| PipelineOptions pipelineOptions, |
| String ptransformId, |
| Supplier<String> processBundleInstructionId, |
| Supplier<List<CacheToken>> cacheTokens, |
| Supplier<Cache<?, ?>> bundleCache, |
| Cache<?, ?> processWideCache, |
| Map<TupleTag<?>, SideInputSpec> sideInputSpecMap, |
| BeamFnStateClient beamFnStateClient, |
| Coder<K> keyCoder, |
| Coder<BoundedWindow> windowCoder, |
| Supplier<K> currentKeySupplier, |
| Supplier<BoundedWindow> currentWindowSupplier) { |
| this.pipelineOptions = pipelineOptions; |
| this.stateKeyObjectCache = Maps.newHashMap(); |
| this.sideInputSpecMap = sideInputSpecMap; |
| this.beamFnStateClient = beamFnStateClient; |
| this.ptransformId = ptransformId; |
| this.processBundleInstructionId = processBundleInstructionId; |
| this.cacheTokens = cacheTokens; |
| this.bundleCache = bundleCache; |
| this.processWideCache = processWideCache; |
| this.stateFinalizers = new ArrayList<>(); |
| this.currentWindowSupplier = currentWindowSupplier; |
| this.encodedCurrentKeySupplier = |
| memoizeFunction( |
| currentKeySupplier, |
| key -> { |
| checkState( |
| keyCoder != null, "Accessing state in unkeyed context, no key coder available"); |
| |
| ByteStringOutputStream encodedKeyOut = new ByteStringOutputStream(); |
| try { |
| ((Coder) keyCoder).encode(key, encodedKeyOut, Coder.Context.NESTED); |
| } catch (IOException e) { |
| throw new IllegalStateException(e); |
| } |
| return encodedKeyOut.toByteString(); |
| }); |
| |
| this.encodedCurrentWindowSupplier = |
| memoizeFunction( |
| currentWindowSupplier, |
| window -> { |
| ByteStringOutputStream encodedWindowOut = new ByteStringOutputStream(); |
| try { |
| windowCoder.encode(window, encodedWindowOut); |
| } catch (IOException e) { |
| throw new IllegalStateException(e); |
| } |
| return encodedWindowOut.toByteString(); |
| }); |
| } |
| |
| private static <ArgT, ResultT> Supplier<ResultT> memoizeFunction( |
| Supplier<ArgT> arg, Function<ArgT, ResultT> f) { |
| return new Supplier<ResultT>() { |
| private ArgT memoizedArg; |
| private ResultT memoizedResult; |
| private boolean initialized; |
| |
| @Override |
| public ResultT get() { |
| ArgT currentArg = arg.get(); |
| if (currentArg != memoizedArg || !initialized) { |
| memoizedResult = f.apply(this.memoizedArg = currentArg); |
| initialized = true; |
| } |
| return memoizedResult; |
| } |
| }; |
| } |
| |
| @Override |
| public @Nullable <T> T get(PCollectionView<T> view, BoundedWindow window) { |
| TupleTag<?> tag = view.getTagInternal(); |
| |
| SideInputSpec sideInputSpec = sideInputSpecMap.get(tag); |
| checkArgument(sideInputSpec != null, "Attempting to access unknown side input %s.", view); |
| |
| ByteStringOutputStream encodedWindowOut = new ByteStringOutputStream(); |
| try { |
| sideInputSpec |
| .getWindowCoder() |
| .encode(sideInputSpec.getWindowMappingFn().getSideInputWindow(window), encodedWindowOut); |
| } catch (IOException e) { |
| throw new IllegalStateException(e); |
| } |
| ByteString encodedWindow = encodedWindowOut.toByteString(); |
| StateKey.Builder cacheKeyBuilder = StateKey.newBuilder(); |
| |
| switch (sideInputSpec.getAccessPattern()) { |
| case Materializations.ITERABLE_MATERIALIZATION_URN: |
| cacheKeyBuilder |
| .getIterableSideInputBuilder() |
| .setTransformId(ptransformId) |
| .setSideInputId(tag.getId()) |
| .setWindow(encodedWindow); |
| break; |
| |
| case Materializations.MULTIMAP_MATERIALIZATION_URN: |
| checkState( |
| sideInputSpec.getCoder() instanceof KvCoder, |
| "Expected %s but received %s.", |
| KvCoder.class, |
| sideInputSpec.getCoder().getClass()); |
| cacheKeyBuilder |
| .getMultimapKeysSideInputBuilder() |
| .setTransformId(ptransformId) |
| .setSideInputId(tag.getId()) |
| .setWindow(encodedWindow); |
| break; |
| |
| default: |
| throw new IllegalStateException( |
| String.format( |
| "This SDK is only capable of dealing with %s materializations " |
| + "but was asked to handle %s for PCollectionView with tag %s.", |
| ImmutableList.of( |
| Materializations.ITERABLE_MATERIALIZATION_URN, |
| Materializations.MULTIMAP_MATERIALIZATION_URN), |
| sideInputSpec.getAccessPattern(), |
| tag)); |
| } |
| return (T) |
| stateKeyObjectCache.computeIfAbsent( |
| cacheKeyBuilder.build(), |
| key -> { |
| switch (sideInputSpec.getAccessPattern()) { |
| case Materializations.ITERABLE_MATERIALIZATION_URN: |
| return sideInputSpec |
| .getViewFn() |
| .apply( |
| new IterableSideInput<>( |
| getCacheFor(key), |
| beamFnStateClient, |
| processBundleInstructionId.get(), |
| key, |
| sideInputSpec.getCoder())); |
| case Materializations.MULTIMAP_MATERIALIZATION_URN: |
| return sideInputSpec |
| .getViewFn() |
| .apply( |
| new MultimapSideInput<>( |
| getCacheFor(key), |
| beamFnStateClient, |
| processBundleInstructionId.get(), |
| key, |
| ((KvCoder) sideInputSpec.getCoder()).getKeyCoder(), |
| ((KvCoder) sideInputSpec.getCoder()).getValueCoder())); |
| default: |
| throw new IllegalStateException( |
| String.format( |
| "This SDK is only capable of dealing with %s materializations " |
| + "but was asked to handle %s for PCollectionView with tag %s.", |
| ImmutableList.of( |
| Materializations.ITERABLE_MATERIALIZATION_URN, |
| Materializations.MULTIMAP_MATERIALIZATION_URN), |
| sideInputSpec.getAccessPattern(), |
| tag)); |
| } |
| }); |
| } |
| |
| @Override |
| public <T> boolean contains(PCollectionView<T> view) { |
| return sideInputSpecMap.containsKey(view.getTagInternal()); |
| } |
| |
| @Override |
| public boolean isEmpty() { |
| return sideInputSpecMap.isEmpty(); |
| } |
| |
| @Override |
| public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) { |
| return (ValueState<T>) |
| stateKeyObjectCache.computeIfAbsent( |
| createBagUserStateKey(id), |
| new Function<StateKey, Object>() { |
| @Override |
| public Object apply(StateKey key) { |
| return new ValueState<T>() { |
| private final BagUserState<T> impl = createBagUserState(key, coder); |
| |
| @Override |
| public void clear() { |
| impl.clear(); |
| } |
| |
| @Override |
| public void write(T input) { |
| impl.clear(); |
| impl.append(input); |
| } |
| |
| @Override |
| public T read() { |
| Iterator<T> value = impl.get().iterator(); |
| if (value.hasNext()) { |
| return value.next(); |
| } else { |
| return null; |
| } |
| } |
| |
| @Override |
| public ValueState<T> readLater() { |
| impl.get().prefetch(); |
| return this; |
| } |
| }; |
| } |
| }); |
| } |
| |
| @Override |
| public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) { |
| return (BagState<T>) |
| stateKeyObjectCache.computeIfAbsent( |
| createBagUserStateKey(id), |
| new Function<StateKey, Object>() { |
| @Override |
| public Object apply(StateKey key) { |
| return new BagState<T>() { |
| private final BagUserState<T> impl = createBagUserState(key, elemCoder); |
| |
| @Override |
| public void add(T value) { |
| impl.append(value); |
| } |
| |
| @Override |
| public ReadableState<Boolean> isEmpty() { |
| return new ReadableState<Boolean>() { |
| @Override |
| public @Nullable Boolean read() { |
| return !impl.get().iterator().hasNext(); |
| } |
| |
| @Override |
| public ReadableState<Boolean> readLater() { |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public Iterable<T> read() { |
| return impl.get(); |
| } |
| |
| @Override |
| public BagState<T> readLater() { |
| impl.get().prefetch(); |
| return this; |
| } |
| |
| @Override |
| public void clear() { |
| impl.clear(); |
| } |
| }; |
| } |
| }); |
| } |
| |
| @Override |
| public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) { |
| return (SetState<T>) |
| stateKeyObjectCache.computeIfAbsent( |
| createMultimapKeysUserStateKey(id), |
| new Function<StateKey, Object>() { |
| @Override |
| public Object apply(StateKey key) { |
| return new SetState<T>() { |
| private final MultimapUserState<T, Void> impl = |
| createMultimapUserState(key, elemCoder, VoidCoder.of()); |
| |
| @Override |
| public void clear() { |
| impl.clear(); |
| } |
| |
| @Override |
| public ReadableState<Boolean> contains(T t) { |
| return new ReadableState<Boolean>() { |
| @Override |
| public Boolean read() { |
| return !Iterables.isEmpty(impl.get(t)); |
| } |
| |
| @Override |
| public ReadableState<Boolean> readLater() { |
| impl.get(t).prefetch(); |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public ReadableState<Boolean> addIfAbsent(T t) { |
| boolean isEmpty = Iterables.isEmpty(impl.get(t)); |
| if (isEmpty) { |
| impl.put(t, null); |
| } |
| return ReadableStates.immediate(isEmpty); |
| } |
| |
| @Override |
| public void remove(T t) { |
| impl.remove(t); |
| } |
| |
| @Override |
| public void add(T value) { |
| impl.remove(value); |
| impl.put(value, null); |
| } |
| |
| @Override |
| public ReadableState<Boolean> isEmpty() { |
| return new ReadableState<Boolean>() { |
| @Override |
| public Boolean read() { |
| return Iterables.isEmpty(impl.keys()); |
| } |
| |
| @Override |
| public ReadableState<Boolean> readLater() { |
| impl.keys().prefetch(); |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public Iterable<T> read() { |
| return impl.keys(); |
| } |
| |
| @Override |
| public SetState<T> readLater() { |
| impl.keys().prefetch(); |
| return this; |
| } |
| }; |
| } |
| }); |
| } |
| |
| @Override |
| public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( |
| String id, |
| StateSpec<MapState<KeyT, ValueT>> spec, |
| Coder<KeyT> mapKeyCoder, |
| Coder<ValueT> mapValueCoder) { |
| return (MapState<KeyT, ValueT>) |
| stateKeyObjectCache.computeIfAbsent( |
| createMultimapKeysUserStateKey(id), |
| new Function<StateKey, Object>() { |
| @Override |
| public Object apply(StateKey key) { |
| return new MapState<KeyT, ValueT>() { |
| private final MultimapUserState<KeyT, ValueT> impl = |
| createMultimapUserState(key, mapKeyCoder, mapValueCoder); |
| |
| @Override |
| public void clear() { |
| impl.clear(); |
| } |
| |
| @Override |
| public void put(KeyT key, ValueT value) { |
| impl.remove(key); |
| impl.put(key, value); |
| } |
| |
| @Override |
| public ReadableState<ValueT> computeIfAbsent( |
| KeyT key, Function<? super KeyT, ? extends ValueT> mappingFunction) { |
| Iterable<ValueT> values = impl.get(key); |
| if (Iterables.isEmpty(values)) { |
| impl.put(key, mappingFunction.apply(key)); |
| } |
| return ReadableStates.immediate(Iterables.getOnlyElement(values, null)); |
| } |
| |
| @Override |
| public void remove(KeyT key) { |
| impl.remove(key); |
| } |
| |
| @Override |
| public ReadableState<ValueT> get(KeyT key) { |
| return getOrDefault(key, null); |
| } |
| |
| @Override |
| public ReadableState<ValueT> getOrDefault( |
| KeyT key, @Nullable ValueT defaultValue) { |
| return new ReadableState<ValueT>() { |
| @Override |
| public @Nullable ValueT read() { |
| Iterable<ValueT> values = impl.get(key); |
| return Iterables.getOnlyElement(values, defaultValue); |
| } |
| |
| @Override |
| public ReadableState<ValueT> readLater() { |
| impl.get(key).prefetch(); |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public ReadableState<Iterable<KeyT>> keys() { |
| return new ReadableState<Iterable<KeyT>>() { |
| @Override |
| public Iterable<KeyT> read() { |
| return impl.keys(); |
| } |
| |
| @Override |
| public ReadableState<Iterable<KeyT>> readLater() { |
| impl.keys().prefetch(); |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public ReadableState<Iterable<ValueT>> values() { |
| return new ReadableState<Iterable<ValueT>>() { |
| @Override |
| public Iterable<ValueT> read() { |
| return Iterables.transform(entries().read(), e -> e.getValue()); |
| } |
| |
| @Override |
| public ReadableState<Iterable<ValueT>> readLater() { |
| entries().readLater(); |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> entries() { |
| return new ReadableState<Iterable<Map.Entry<KeyT, ValueT>>>() { |
| @Override |
| public Iterable<Map.Entry<KeyT, ValueT>> read() { |
| Iterable<KeyT> keys = keys().read(); |
| return Iterables.transform( |
| keys, key -> Maps.immutableEntry(key, get(key).read())); |
| } |
| |
| @Override |
| public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> readLater() { |
| // Start prefetching the keys. We would need to block to start prefetching |
| // the values. |
| keys().readLater(); |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public ReadableState<Boolean> isEmpty() { |
| return new ReadableState<Boolean>() { |
| @Override |
| public Boolean read() { |
| return Iterables.isEmpty(keys().read()); |
| } |
| |
| @Override |
| public ReadableState<Boolean> readLater() { |
| keys().readLater(); |
| return this; |
| } |
| }; |
| } |
| }; |
| } |
| }); |
| } |
| |
| @Override |
| public <T> OrderedListState<T> bindOrderedList( |
| String id, StateSpec<OrderedListState<T>> spec, Coder<T> elemCoder) { |
| throw new UnsupportedOperationException( |
| "TODO: Add support for a sorted-list state to the Fn API."); |
| } |
| |
| @Override |
| public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining( |
| String id, |
| StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec, |
| Coder<AccumT> accumCoder, |
| CombineFn<ElementT, AccumT, ResultT> combineFn) { |
| return (CombiningState<ElementT, AccumT, ResultT>) |
| stateKeyObjectCache.computeIfAbsent( |
| createBagUserStateKey(id), |
| new Function<StateKey, Object>() { |
| @Override |
| public Object apply(StateKey key) { |
| // TODO: Support squashing accumulators depending on whether we know of all |
| // remote accumulators and local accumulators or just local accumulators. |
| return new CombiningState<ElementT, AccumT, ResultT>() { |
| private final BagUserState<AccumT> impl = createBagUserState(key, accumCoder); |
| |
| @Override |
| public AccumT getAccum() { |
| Iterator<AccumT> iterator = impl.get().iterator(); |
| if (iterator.hasNext()) { |
| return iterator.next(); |
| } |
| return combineFn.createAccumulator(); |
| } |
| |
| @Override |
| public void addAccum(AccumT accum) { |
| Iterator<AccumT> iterator = impl.get().iterator(); |
| |
| // Only merge if there was a prior value |
| if (iterator.hasNext()) { |
| accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum)); |
| // Since there was a prior value, we need to clear. |
| impl.clear(); |
| } |
| |
| impl.append(accum); |
| } |
| |
| @Override |
| public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { |
| return combineFn.mergeAccumulators(accumulators); |
| } |
| |
| @Override |
| public CombiningState<ElementT, AccumT, ResultT> readLater() { |
| impl.get().prefetch(); |
| return this; |
| } |
| |
| @Override |
| public ResultT read() { |
| Iterator<AccumT> iterator = impl.get().iterator(); |
| if (iterator.hasNext()) { |
| return combineFn.extractOutput(iterator.next()); |
| } |
| return combineFn.defaultValue(); |
| } |
| |
| @Override |
| public void add(ElementT value) { |
| AccumT newAccumulator = combineFn.addInput(getAccum(), value); |
| impl.clear(); |
| impl.append(newAccumulator); |
| } |
| |
| @Override |
| public ReadableState<Boolean> isEmpty() { |
| return new ReadableState<Boolean>() { |
| @Override |
| public @Nullable Boolean read() { |
| return !impl.get().iterator().hasNext(); |
| } |
| |
| @Override |
| public ReadableState<Boolean> readLater() { |
| impl.get().prefetch(); |
| return this; |
| } |
| }; |
| } |
| |
| @Override |
| public void clear() { |
| impl.clear(); |
| } |
| }; |
| } |
| }); |
| } |
| |
| @Override |
| public <ElementT, AccumT, ResultT> |
| CombiningState<ElementT, AccumT, ResultT> bindCombiningWithContext( |
| String id, |
| StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec, |
| Coder<AccumT> accumCoder, |
| CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) { |
| return (CombiningState<ElementT, AccumT, ResultT>) |
| stateKeyObjectCache.computeIfAbsent( |
| createBagUserStateKey(id), |
| key -> |
| bindCombining( |
| id, |
| spec, |
| accumCoder, |
| CombineFnUtil.bindContext( |
| combineFn, |
| new StateContext<BoundedWindow>() { |
| @Override |
| public PipelineOptions getPipelineOptions() { |
| return pipelineOptions; |
| } |
| |
| @Override |
| public <T> T sideInput(PCollectionView<T> view) { |
| return get(view, currentWindowSupplier.get()); |
| } |
| |
| @Override |
| public BoundedWindow window() { |
| return currentWindowSupplier.get(); |
| } |
| }))); |
| } |
| |
| /** |
| * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing and is |
| * waiting on resolution of BEAM-2535. |
| */ |
| @Override |
| @Deprecated |
| public WatermarkHoldState bindWatermark( |
| String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) { |
| throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API."); |
| } |
| |
| private Cache<?, ?> getCacheFor(StateKey stateKey) { |
| switch (stateKey.getTypeCase()) { |
| case BAG_USER_STATE: |
| for (CacheToken token : cacheTokens.get()) { |
| if (!token.hasUserState()) { |
| continue; |
| } |
| return Caches.subCache(processWideCache, token, stateKey); |
| } |
| break; |
| case MULTIMAP_KEYS_USER_STATE: |
| for (CacheToken token : cacheTokens.get()) { |
| if (!token.hasUserState()) { |
| continue; |
| } |
| return Caches.subCache(processWideCache, token, stateKey); |
| } |
| break; |
| case ITERABLE_SIDE_INPUT: |
| for (CacheToken token : cacheTokens.get()) { |
| if (!token.hasSideInput()) { |
| continue; |
| } |
| if (stateKey |
| .getIterableSideInput() |
| .getTransformId() |
| .equals(token.getSideInput().getTransformId()) |
| && stateKey |
| .getIterableSideInput() |
| .getSideInputId() |
| .equals(token.getSideInput().getSideInputId())) { |
| return Caches.subCache(processWideCache, token, stateKey); |
| } |
| } |
| break; |
| case MULTIMAP_KEYS_SIDE_INPUT: |
| for (CacheToken token : cacheTokens.get()) { |
| if (!token.hasSideInput()) { |
| continue; |
| } |
| if (stateKey |
| .getMultimapKeysSideInput() |
| .getTransformId() |
| .equals(token.getSideInput().getTransformId()) |
| && stateKey |
| .getMultimapKeysSideInput() |
| .getSideInputId() |
| .equals(token.getSideInput().getSideInputId())) { |
| return Caches.subCache(processWideCache, token, stateKey); |
| } |
| } |
| break; |
| default: |
| throw new IllegalStateException( |
| String.format("Unknown state key type requested %s.", stateKey)); |
| } |
| // The default is to use the bundle cache. |
| return Caches.subCache(bundleCache.get(), stateKey); |
| } |
| |
| private <T> BagUserState<T> createBagUserState(StateKey stateKey, Coder<T> valueCoder) { |
| BagUserState<T> rval = |
| new BagUserState<>( |
| getCacheFor(stateKey), |
| beamFnStateClient, |
| processBundleInstructionId.get(), |
| stateKey, |
| valueCoder); |
| stateFinalizers.add(rval::asyncClose); |
| return rval; |
| } |
| |
| private StateKey createBagUserStateKey(String stateId) { |
| StateKey.Builder builder = StateKey.newBuilder(); |
| builder |
| .getBagUserStateBuilder() |
| .setWindow(encodedCurrentWindowSupplier.get()) |
| .setKey(encodedCurrentKeySupplier.get()) |
| .setTransformId(ptransformId) |
| .setUserStateId(stateId); |
| return builder.build(); |
| } |
| |
| private <KeyT, ValueT> MultimapUserState<KeyT, ValueT> createMultimapUserState( |
| StateKey stateKey, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) { |
| MultimapUserState<KeyT, ValueT> rval = |
| new MultimapUserState( |
| getCacheFor(stateKey), |
| beamFnStateClient, |
| processBundleInstructionId.get(), |
| stateKey, |
| keyCoder, |
| valueCoder); |
| stateFinalizers.add(rval::asyncClose); |
| return rval; |
| } |
| |
| private StateKey createMultimapKeysUserStateKey(String stateId) { |
| StateKey.Builder builder = StateKey.newBuilder(); |
| builder |
| .getMultimapKeysUserStateBuilder() |
| .setWindow(encodedCurrentWindowSupplier.get()) |
| .setTransformId(ptransformId) |
| .setUserStateId(stateId); |
| return builder.build(); |
| } |
| |
| public void finalizeState() { |
| // Persist all dirty state cells |
| try { |
| for (ThrowingRunnable runnable : stateFinalizers) { |
| runnable.run(); |
| } |
| } catch (InterruptedException e) { |
| Thread.currentThread().interrupt(); |
| throw new IllegalStateException(e); |
| } catch (Exception e) { |
| throw new IllegalStateException(e); |
| } |
| stateFinalizers.clear(); |
| stateKeyObjectCache.clear(); |
| } |
| } |